feat: Handle refusal in OpenAI's Structured Outputs API #533

Merged · 1 commit · Aug 21, 2024

53 changes: 4 additions & 49 deletions packages/langchain_openai/lib/src/chat_models/chat_openai.dart
@@ -248,9 +248,10 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
     final ChatOpenAIOptions? options,
   }) async {
     final completion = await _client.createChatCompletion(
-      request: _createChatCompletionRequest(
+      request: createChatCompletionRequest(
         input.toChatMessages(),
         options: options,
+        defaultOptions: defaultOptions,
       ),
     );
     return completion.toChatResult(completion.id ?? _uuid.v4());
@@ -263,9 +264,10 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
   }) {
     return _client
        .createChatCompletionStream(
-          request: _createChatCompletionRequest(
+          request: createChatCompletionRequest(
             input.toChatMessages(),
             options: options,
+            defaultOptions: defaultOptions,
             stream: true,
           ),
         )
@@ -275,53 +277,6 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
         );
   }
-
-  /// Creates a [CreateChatCompletionRequest] from the given input.
-  CreateChatCompletionRequest _createChatCompletionRequest(
-    final List<ChatMessage> messages, {
-    final ChatOpenAIOptions? options,
-    final bool stream = false,
-  }) {
-    final messagesDtos = messages.toChatCompletionMessages();
-    final toolsDtos =
-        (options?.tools ?? defaultOptions.tools)?.toChatCompletionTool();
-    final toolChoice = (options?.toolChoice ?? defaultOptions.toolChoice)
-        ?.toChatCompletionToolChoice();
-    final responseFormatDto =
-        (options?.responseFormat ?? defaultOptions.responseFormat)
-            ?.toChatCompletionResponseFormat();
-    final serviceTierDto = (options?.serviceTier ?? defaultOptions.serviceTier)
-        .toCreateChatCompletionRequestServiceTier();
-
-    return CreateChatCompletionRequest(
-      model: ChatCompletionModel.modelId(
-        options?.model ?? defaultOptions.model ?? defaultModel,
-      ),
-      messages: messagesDtos,
-      tools: toolsDtos,
-      toolChoice: toolChoice,
-      frequencyPenalty:
-          options?.frequencyPenalty ?? defaultOptions.frequencyPenalty,
-      logitBias: options?.logitBias ?? defaultOptions.logitBias,
-      maxTokens: options?.maxTokens ?? defaultOptions.maxTokens,
-      n: options?.n ?? defaultOptions.n,
-      presencePenalty:
-          options?.presencePenalty ?? defaultOptions.presencePenalty,
-      responseFormat: responseFormatDto,
-      seed: options?.seed ?? defaultOptions.seed,
-      stop: (options?.stop ?? defaultOptions.stop) != null
-          ? ChatCompletionStop.listString(options?.stop ?? defaultOptions.stop!)
-          : null,
-      temperature: options?.temperature ?? defaultOptions.temperature,
-      topP: options?.topP ?? defaultOptions.topP,
-      parallelToolCalls:
-          options?.parallelToolCalls ?? defaultOptions.parallelToolCalls,
-      serviceTier: serviceTierDto,
-      user: options?.user ?? defaultOptions.user,
-      streamOptions:
-          stream ? const ChatCompletionStreamOptions(includeUsage: true) : null,
-    );
-  }
 
   /// Tokenizes the given prompt using tiktoken with the encoding used by the
   /// [model]. If an encoding model is specified in [encoding] field, that
   /// encoding is used instead.
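
With `_createChatCompletionRequest` removed from the class, the extracted top-level `createChatCompletionRequest` helper (see mappers.dart below) can no longer read the model's instance state, which is why both call sites now pass `defaultOptions` explicitly. A minimal sketch of calling the helper directly, assuming the package's usual `ChatMessage` and `ChatOpenAIOptions` APIs (the message text, model name, and temperature are illustrative placeholders, not from this PR):

```dart
// Sketch only: builds a request the same way ChatOpenAI now does internally.
// The model name and temperature below are illustrative placeholders.
final request = createChatCompletionRequest(
  [ChatMessage.humanText('Summarize this document as JSON.')],
  options: const ChatOpenAIOptions(temperature: 0),
  defaultOptions: const ChatOpenAIOptions(model: 'gpt-4o-mini'),
);
```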
128 changes: 103 additions & 25 deletions packages/langchain_openai/lib/src/chat_models/mappers.dart
@@ -6,45 +6,93 @@ import 'package:langchain_core/language_models.dart';
 import 'package:langchain_core/tools.dart';
 import 'package:openai_dart/openai_dart.dart';
 
+import 'chat_openai.dart';
 import 'types.dart';
 
+/// Creates a [CreateChatCompletionRequest] from the given input.
+CreateChatCompletionRequest createChatCompletionRequest(
+  final List<ChatMessage> messages, {
+  required final ChatOpenAIOptions? options,
+  required final ChatOpenAIOptions defaultOptions,
+  final bool stream = false,
+}) {
+  final messagesDtos = messages.toChatCompletionMessages();
+  final toolsDtos =
+      (options?.tools ?? defaultOptions.tools)?.toChatCompletionTool();
+  final toolChoice = (options?.toolChoice ?? defaultOptions.toolChoice)
+      ?.toChatCompletionToolChoice();
+  final responseFormatDto =
+      (options?.responseFormat ?? defaultOptions.responseFormat)
+          ?.toChatCompletionResponseFormat();
+  final serviceTierDto = (options?.serviceTier ?? defaultOptions.serviceTier)
+      .toCreateChatCompletionRequestServiceTier();
+
+  return CreateChatCompletionRequest(
+    model: ChatCompletionModel.modelId(
+      options?.model ?? defaultOptions.model ?? ChatOpenAI.defaultModel,
+    ),
+    messages: messagesDtos,
+    tools: toolsDtos,
+    toolChoice: toolChoice,
+    frequencyPenalty:
+        options?.frequencyPenalty ?? defaultOptions.frequencyPenalty,
+    logitBias: options?.logitBias ?? defaultOptions.logitBias,
+    maxTokens: options?.maxTokens ?? defaultOptions.maxTokens,
+    n: options?.n ?? defaultOptions.n,
+    presencePenalty: options?.presencePenalty ?? defaultOptions.presencePenalty,
+    responseFormat: responseFormatDto,
+    seed: options?.seed ?? defaultOptions.seed,
+    stop: (options?.stop ?? defaultOptions.stop) != null
+        ? ChatCompletionStop.listString(options?.stop ?? defaultOptions.stop!)
+        : null,
+    temperature: options?.temperature ?? defaultOptions.temperature,
+    topP: options?.topP ?? defaultOptions.topP,
+    parallelToolCalls:
+        options?.parallelToolCalls ?? defaultOptions.parallelToolCalls,
+    serviceTier: serviceTierDto,
+    user: options?.user ?? defaultOptions.user,
+    streamOptions:
+        stream ? const ChatCompletionStreamOptions(includeUsage: true) : null,
+  );
+}
+
 extension ChatMessageListMapper on List<ChatMessage> {
   List<ChatCompletionMessage> toChatCompletionMessages() {
     return map(_mapMessage).toList(growable: false);
   }
 
   ChatCompletionMessage _mapMessage(final ChatMessage msg) {
     return switch (msg) {
-      final SystemChatMessage systemChatMessage => ChatCompletionMessage.system(
-        content: systemChatMessage.content,
-      ),
-      final HumanChatMessage humanChatMessage => ChatCompletionMessage.user(
-        content: switch (humanChatMessage.content) {
-          final ChatMessageContentText c => _mapMessageContentString(c),
-          final ChatMessageContentImage c =>
-            ChatCompletionUserMessageContent.parts(
-              [_mapMessageContentPartImage(c)],
-            ),
-          final ChatMessageContentMultiModal c => _mapMessageContentPart(c),
-        },
-      ),
-      final AIChatMessage aiChatMessage => ChatCompletionMessage.assistant(
-        content: aiChatMessage.content,
-        toolCalls: aiChatMessage.toolCalls.isNotEmpty
-            ? aiChatMessage.toolCalls
-                .map(_mapMessageToolCall)
-                .toList(growable: false)
-            : null,
-      ),
-      final ToolChatMessage toolChatMessage => ChatCompletionMessage.tool(
-        toolCallId: toolChatMessage.toolCallId,
-        content: toolChatMessage.content,
-      ),
+      final SystemChatMessage msg => _mapSystemMessage(msg),
+      final HumanChatMessage msg => _mapHumanMessage(msg),
+      final AIChatMessage msg => _mapAIMessage(msg),
+      final ToolChatMessage msg => _mapToolMessage(msg),
       CustomChatMessage() =>
         throw UnsupportedError('OpenAI does not support custom messages'),
     };
   }
 
+  ChatCompletionMessage _mapSystemMessage(
+    final SystemChatMessage systemChatMessage,
+  ) {
+    return ChatCompletionMessage.system(content: systemChatMessage.content);
+  }
+
+  ChatCompletionMessage _mapHumanMessage(
+    final HumanChatMessage humanChatMessage,
+  ) {
+    return ChatCompletionMessage.user(
+      content: switch (humanChatMessage.content) {
+        final ChatMessageContentText c => _mapMessageContentString(c),
+        final ChatMessageContentImage c =>
+          ChatCompletionUserMessageContent.parts(
+            [_mapMessageContentPartImage(c)],
+          ),
+        final ChatMessageContentMultiModal c => _mapMessageContentPart(c),
+      },
+    );
+  }
+
   ChatCompletionUserMessageContentString _mapMessageContentString(
     final ChatMessageContentText c,
   ) {
@@ -105,6 +153,17 @@ extension ChatMessageListMapper on List<ChatMessage> {
     return ChatCompletionMessageContentParts(partsList);
   }
 
+  ChatCompletionMessage _mapAIMessage(final AIChatMessage aiChatMessage) {
+    return ChatCompletionMessage.assistant(
+      content: aiChatMessage.content,
+      toolCalls: aiChatMessage.toolCalls.isNotEmpty
+          ? aiChatMessage.toolCalls
+              .map(_mapMessageToolCall)
+              .toList(growable: false)
+          : null,
+    );
+  }
+
   ChatCompletionMessageToolCall _mapMessageToolCall(
     final AIChatMessageToolCall toolCall,
   ) {
@@ -117,12 +176,26 @@ extension ChatMessageListMapper on List<ChatMessage> {
       ),
     );
   }
+
+  ChatCompletionMessage _mapToolMessage(
+    final ToolChatMessage toolChatMessage,
+  ) {
+    return ChatCompletionMessage.tool(
+      toolCallId: toolChatMessage.toolCallId,
+      content: toolChatMessage.content,
+    );
+  }
 }
 
 extension CreateChatCompletionResponseMapper on CreateChatCompletionResponse {
   ChatResult toChatResult(final String id) {
     final choice = choices.first;
     final msg = choice.message;
+
+    if (msg.refusal != null && msg.refusal!.isNotEmpty) {
+      throw OpenAIRefusalException(msg.refusal!);
+    }
+
     return ChatResult(
       id: id,
       output: AIChatMessage(
@@ -211,6 +284,11 @@ extension CreateChatCompletionStreamResponseMapper
   ChatResult toChatResult(final String id) {
     final choice = choices.firstOrNull;
     final delta = choice?.delta;
+
+    if (delta?.refusal != null && delta!.refusal!.isNotEmpty) {
+      throw OpenAIRefusalException(delta.refusal!);
+    }
+
     return ChatResult(
       id: id,
       output: AIChatMessage(
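
Both response mappers now surface refusals by throwing `OpenAIRefusalException` instead of returning the refusal text as regular message content. A hedged sketch of the caller-side handling this enables, assuming the package's usual `ChatOpenAI` and `PromptValue` APIs (the API key variable, model name, and prompt are placeholders):

```dart
final chatModel = ChatOpenAI(
  apiKey: openAiApiKey, // placeholder: your OpenAI API key
  defaultOptions: const ChatOpenAIOptions(model: 'gpt-4o'),
);
try {
  final result = await chatModel.invoke(
    PromptValue.string(userGeneratedInput), // placeholder prompt
  );
  print(result.output.content);
} on OpenAIRefusalException catch (e) {
  // The model declined to produce the structured output.
  print('Refusal: ${e.message}');
}
```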
22 changes: 22 additions & 0 deletions packages/langchain_openai/lib/src/chat_models/types.dart
@@ -379,3 +379,25 @@ enum ChatOpenAIServiceTier {
   /// uptime SLA and no latency guarantee.
   vDefault,
 }
+
+/// {@template openai_refusal_exception}
+/// Exception thrown when OpenAI's Structured Outputs API returns a refusal.
+///
+/// When using OpenAI's Structured Outputs API with user-generated input, the
+/// model may occasionally refuse to fulfill the request for safety reasons.
+///
+/// See here for more on refusals:
+/// https://platform.openai.com/docs/guides/structured-outputs/refusals
+/// {@endtemplate}
+class OpenAIRefusalException implements Exception {
+  /// {@macro openai_refusal_exception}
+  const OpenAIRefusalException(this.message);
+
+  /// The refusal message.
+  final String message;
+
+  @override
+  String toString() {
+    return 'OpenAIRefusalException: $message';
+  }
+}
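
Because the streaming mapper checks each chunk's `delta.refusal`, a refusal surfaces as an exception thrown mid-stream rather than as emitted content. A sketch of handling that case, under the same assumptions as the previous example (placeholders for the model handle and prompt):

```dart
try {
  // `chatModel` as constructed in the previous sketch.
  final stream = chatModel.stream(PromptValue.string(userGeneratedInput));
  await for (final chunk in stream) {
    print(chunk.output.content);
  }
} on OpenAIRefusalException catch (e) {
  print('Refused while streaming: ${e.message}');
}
```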