From 99350fb63c79b9d2d63b386d808a2c9421fae8bd Mon Sep 17 00:00:00 2001
From: David Miguel
Date: Wed, 21 Aug 2024 22:44:53 +0200
Subject: [PATCH] feat: Handle refusal in OpenAI's Structured Outputs API

---
 .../lib/src/chat_models/chat_openai.dart      |  53 +-------
 .../lib/src/chat_models/mappers.dart          | 128 ++++++++++++++----
 .../lib/src/chat_models/types.dart            |  22 +++
 3 files changed, 129 insertions(+), 74 deletions(-)

diff --git a/packages/langchain_openai/lib/src/chat_models/chat_openai.dart b/packages/langchain_openai/lib/src/chat_models/chat_openai.dart
index 0dc31168..c8a670f5 100644
--- a/packages/langchain_openai/lib/src/chat_models/chat_openai.dart
+++ b/packages/langchain_openai/lib/src/chat_models/chat_openai.dart
@@ -248,9 +248,10 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
     final ChatOpenAIOptions? options,
   }) async {
     final completion = await _client.createChatCompletion(
-      request: _createChatCompletionRequest(
+      request: createChatCompletionRequest(
        input.toChatMessages(),
        options: options,
+        defaultOptions: defaultOptions,
      ),
    );
    return completion.toChatResult(completion.id ?? _uuid.v4());
@@ -263,9 +264,10 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
   }) {
     return _client
         .createChatCompletionStream(
-          request: _createChatCompletionRequest(
+          request: createChatCompletionRequest(
            input.toChatMessages(),
            options: options,
+            defaultOptions: defaultOptions,
            stream: true,
          ),
        )
@@ -275,53 +277,6 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
     );
   }
 
-  /// Creates a [CreateChatCompletionRequest] from the given input.
-  CreateChatCompletionRequest _createChatCompletionRequest(
-    final List<ChatMessage> messages, {
-    final ChatOpenAIOptions? options,
-    final bool stream = false,
-  }) {
-    final messagesDtos = messages.toChatCompletionMessages();
-    final toolsDtos =
-        (options?.tools ?? defaultOptions.tools)?.toChatCompletionTool();
-    final toolChoice = (options?.toolChoice ?? defaultOptions.toolChoice)
-        ?.toChatCompletionToolChoice();
-    final responseFormatDto =
-        (options?.responseFormat ?? defaultOptions.responseFormat)
-            ?.toChatCompletionResponseFormat();
-    final serviceTierDto = (options?.serviceTier ?? defaultOptions.serviceTier)
-        .toCreateChatCompletionRequestServiceTier();
-
-    return CreateChatCompletionRequest(
-      model: ChatCompletionModel.modelId(
-        options?.model ?? defaultOptions.model ?? defaultModel,
-      ),
-      messages: messagesDtos,
-      tools: toolsDtos,
-      toolChoice: toolChoice,
-      frequencyPenalty:
-          options?.frequencyPenalty ?? defaultOptions.frequencyPenalty,
-      logitBias: options?.logitBias ?? defaultOptions.logitBias,
-      maxTokens: options?.maxTokens ?? defaultOptions.maxTokens,
-      n: options?.n ?? defaultOptions.n,
-      presencePenalty:
-          options?.presencePenalty ?? defaultOptions.presencePenalty,
-      responseFormat: responseFormatDto,
-      seed: options?.seed ?? defaultOptions.seed,
-      stop: (options?.stop ?? defaultOptions.stop) != null
-          ? ChatCompletionStop.listString(options?.stop ?? defaultOptions.stop!)
-          : null,
-      temperature: options?.temperature ?? defaultOptions.temperature,
-      topP: options?.topP ?? defaultOptions.topP,
-      parallelToolCalls:
-          options?.parallelToolCalls ?? defaultOptions.parallelToolCalls,
-      serviceTier: serviceTierDto,
-      user: options?.user ?? defaultOptions.user,
-      streamOptions:
-          stream ? const ChatCompletionStreamOptions(includeUsage: true) : null,
-    );
-  }
-
   /// Tokenizes the given prompt using tiktoken with the encoding used by the
   /// [model]. If an encoding model is specified in [encoding] field, that
   /// encoding is used instead.
diff --git a/packages/langchain_openai/lib/src/chat_models/mappers.dart b/packages/langchain_openai/lib/src/chat_models/mappers.dart
index 5e9000c2..a2ea96f4 100644
--- a/packages/langchain_openai/lib/src/chat_models/mappers.dart
+++ b/packages/langchain_openai/lib/src/chat_models/mappers.dart
@@ -6,8 +6,56 @@ import 'package:langchain_core/language_models.dart';
 import 'package:langchain_core/tools.dart';
 import 'package:openai_dart/openai_dart.dart';
 
+import 'chat_openai.dart';
 import 'types.dart';
 
+/// Creates a [CreateChatCompletionRequest] from the given input.
+CreateChatCompletionRequest createChatCompletionRequest(
+  final List<ChatMessage> messages, {
+  required final ChatOpenAIOptions? options,
+  required final ChatOpenAIOptions defaultOptions,
+  final bool stream = false,
+}) {
+  final messagesDtos = messages.toChatCompletionMessages();
+  final toolsDtos =
+      (options?.tools ?? defaultOptions.tools)?.toChatCompletionTool();
+  final toolChoice = (options?.toolChoice ?? defaultOptions.toolChoice)
+      ?.toChatCompletionToolChoice();
+  final responseFormatDto =
+      (options?.responseFormat ?? defaultOptions.responseFormat)
+          ?.toChatCompletionResponseFormat();
+  final serviceTierDto = (options?.serviceTier ?? defaultOptions.serviceTier)
+      .toCreateChatCompletionRequestServiceTier();
+
+  return CreateChatCompletionRequest(
+    model: ChatCompletionModel.modelId(
+      options?.model ?? defaultOptions.model ?? ChatOpenAI.defaultModel,
+    ),
+    messages: messagesDtos,
+    tools: toolsDtos,
+    toolChoice: toolChoice,
+    frequencyPenalty:
+        options?.frequencyPenalty ?? defaultOptions.frequencyPenalty,
+    logitBias: options?.logitBias ?? defaultOptions.logitBias,
+    maxTokens: options?.maxTokens ?? defaultOptions.maxTokens,
+    n: options?.n ?? defaultOptions.n,
+    presencePenalty: options?.presencePenalty ?? defaultOptions.presencePenalty,
+    responseFormat: responseFormatDto,
+    seed: options?.seed ?? defaultOptions.seed,
+    stop: (options?.stop ?? defaultOptions.stop) != null
+        ? ChatCompletionStop.listString(options?.stop ?? defaultOptions.stop!)
+        : null,
+    temperature: options?.temperature ?? defaultOptions.temperature,
+    topP: options?.topP ?? defaultOptions.topP,
+    parallelToolCalls:
+        options?.parallelToolCalls ?? defaultOptions.parallelToolCalls,
+    serviceTier: serviceTierDto,
+    user: options?.user ?? defaultOptions.user,
+    streamOptions:
+        stream ? const ChatCompletionStreamOptions(includeUsage: true) : null,
+  );
+}
+
 extension ChatMessageListMapper on List<ChatMessage> {
   List<ChatCompletionMessage> toChatCompletionMessages() {
     return map(_mapMessage).toList(growable: false);
@@ -15,36 +63,36 @@ extension ChatMessageListMapper on List<ChatMessage> {
 
   ChatCompletionMessage _mapMessage(final ChatMessage msg) {
     return switch (msg) {
-      final SystemChatMessage systemChatMessage => ChatCompletionMessage.system(
-        content: systemChatMessage.content,
-      ),
-      final HumanChatMessage humanChatMessage => ChatCompletionMessage.user(
-        content: switch (humanChatMessage.content) {
-          final ChatMessageContentText c => _mapMessageContentString(c),
-          final ChatMessageContentImage c =>
-            ChatCompletionUserMessageContent.parts(
-              [_mapMessageContentPartImage(c)],
-            ),
-          final ChatMessageContentMultiModal c => _mapMessageContentPart(c),
-        },
-      ),
-      final AIChatMessage aiChatMessage => ChatCompletionMessage.assistant(
-        content: aiChatMessage.content,
-        toolCalls: aiChatMessage.toolCalls.isNotEmpty
-            ? aiChatMessage.toolCalls
-                .map(_mapMessageToolCall)
-                .toList(growable: false)
-            : null,
-      ),
-      final ToolChatMessage toolChatMessage => ChatCompletionMessage.tool(
-        toolCallId: toolChatMessage.toolCallId,
-        content: toolChatMessage.content,
-      ),
+      final SystemChatMessage msg => _mapSystemMessage(msg),
+      final HumanChatMessage msg => _mapHumanMessage(msg),
+      final AIChatMessage msg => _mapAIMessage(msg),
+      final ToolChatMessage msg => _mapToolMessage(msg),
       CustomChatMessage() =>
         throw UnsupportedError('OpenAI does not support custom messages'),
     };
   }
 
+  ChatCompletionMessage _mapSystemMessage(
+    final SystemChatMessage systemChatMessage,
+  ) {
+    return ChatCompletionMessage.system(content: systemChatMessage.content);
+  }
+
+  ChatCompletionMessage _mapHumanMessage(
+    final HumanChatMessage humanChatMessage,
+  ) {
+    return ChatCompletionMessage.user(
+      content: switch (humanChatMessage.content) {
+        final ChatMessageContentText c => _mapMessageContentString(c),
+        final ChatMessageContentImage c =>
+          ChatCompletionUserMessageContent.parts(
+            [_mapMessageContentPartImage(c)],
+          ),
+        final ChatMessageContentMultiModal c => _mapMessageContentPart(c),
+      },
+    );
+  }
+
   ChatCompletionUserMessageContentString _mapMessageContentString(
     final ChatMessageContentText c,
   ) {
@@ -105,6 +153,17 @@ extension ChatMessageListMapper on List<ChatMessage> {
     return ChatCompletionMessageContentParts(partsList);
   }
 
+  ChatCompletionMessage _mapAIMessage(final AIChatMessage aiChatMessage) {
+    return ChatCompletionMessage.assistant(
+      content: aiChatMessage.content,
+      toolCalls: aiChatMessage.toolCalls.isNotEmpty
+          ? aiChatMessage.toolCalls
+              .map(_mapMessageToolCall)
+              .toList(growable: false)
+          : null,
+    );
+  }
+
   ChatCompletionMessageToolCall _mapMessageToolCall(
     final AIChatMessageToolCall toolCall,
   ) {
@@ -117,12 +176,26 @@ extension ChatMessageListMapper on List<ChatMessage> {
       ),
     );
   }
+
+  ChatCompletionMessage _mapToolMessage(
+    final ToolChatMessage toolChatMessage,
+  ) {
+    return ChatCompletionMessage.tool(
+      toolCallId: toolChatMessage.toolCallId,
+      content: toolChatMessage.content,
+    );
+  }
 }
 
 extension CreateChatCompletionResponseMapper on CreateChatCompletionResponse {
   ChatResult toChatResult(final String id) {
     final choice = choices.first;
     final msg = choice.message;
+
+    if (msg.refusal != null && msg.refusal!.isNotEmpty) {
+      throw OpenAIRefusalException(msg.refusal!);
+    }
+
     return ChatResult(
       id: id,
       output: AIChatMessage(
@@ -211,6 +284,11 @@ extension CreateChatCompletionStreamResponseMapper
   ChatResult toChatResult(final String id) {
     final choice = choices.firstOrNull;
     final delta = choice?.delta;
+
+    if (delta?.refusal != null && delta!.refusal!.isNotEmpty) {
+      throw OpenAIRefusalException(delta.refusal!);
+    }
+
     return ChatResult(
       id: id,
       output: AIChatMessage(
diff --git a/packages/langchain_openai/lib/src/chat_models/types.dart b/packages/langchain_openai/lib/src/chat_models/types.dart
index 0c80184f..6713a56f 100644
--- a/packages/langchain_openai/lib/src/chat_models/types.dart
+++ b/packages/langchain_openai/lib/src/chat_models/types.dart
@@ -379,3 +379,25 @@ enum ChatOpenAIServiceTier {
   /// uptime SLA and no latency guarantee.
   vDefault,
 }
+
+/// {@template openai_refusal_exception}
+/// Exception thrown when OpenAI Structured Outputs API returns a refusal.
+///
+/// When using OpenAI's Structured Outputs API with user-generated input, the
+/// model may occasionally refuse to fulfill the request for safety reasons.
+///
+/// See here for more on refusals:
+/// https://platform.openai.com/docs/guides/structured-outputs/refusals
+/// {@endtemplate}
+class OpenAIRefusalException implements Exception {
+  /// {@macro openai_refusal_exception}
+  const OpenAIRefusalException(this.message);
+
+  /// The refusal message.
+  final String message;
+
+  @override
+  String toString() {
+    return 'OpenAIRefusalException: $message';
+  }
+}
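
Note for reviewers: below is a minimal usage sketch (not part of the patch)
of how a caller might handle the new exception. It assumes a `ChatOpenAI`
instance already configured for Structured Outputs via
`ChatOpenAIOptions.responseFormat`; the API key, prompt text, and options
setup are placeholders.

  import 'package:langchain_core/prompts.dart';
  import 'package:langchain_openai/langchain_openai.dart';

  Future<void> main() async {
    // Placeholder setup; real code would also configure a JSON-schema
    // response format in the options.
    final chatModel = ChatOpenAI(apiKey: 'sk-...');
    try {
      final result = await chatModel.invoke(
        PromptValue.string('Extract a user profile from this text...'),
      );
      print(result.output.content);
    } on OpenAIRefusalException catch (e) {
      // The model declined the request for safety reasons; surface the
      // refusal message instead of treating it as model output.
      print('Refusal: ${e.message}');
    }
  }

Both the non-streaming and streaming mappers throw this exception, so the
same try/catch applies around a `stream()` subscription.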