From d254aa5e3a6db9e32ffb7b99f2cf13e88cde4ec2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=91=E8=BD=BB=E7=8B=82?= <1677568218@qq.com> Date: Fri, 11 Aug 2023 13:50:53 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=97=AD=E5=8C=85=E6=94=B9=E6=88=90cla?= =?UTF-8?q?ss?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- package/motion/__tests__/function.test.ts | 2 +- package/motion/__tests__/sentiment.test.ts | 24 ++--- package/motion/function.ts | 27 +++--- package/motion/typechat.ts | 103 ++++++--------------- 4 files changed, 48 insertions(+), 108 deletions(-) diff --git a/package/motion/__tests__/function.test.ts b/package/motion/__tests__/function.test.ts index 56cc4b1..a0414bc 100644 --- a/package/motion/__tests__/function.test.ts +++ b/package/motion/__tests__/function.test.ts @@ -54,7 +54,7 @@ const reqMessages: messagesType = [ "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", }, ]; -const chain = FunctionChain(llm); +const chain = new FunctionChain(llm); const res = await chain.call({ request: { role: "user", content: "What's the weather like today" }, prompt: reqMessages, diff --git a/package/motion/__tests__/sentiment.test.ts b/package/motion/__tests__/sentiment.test.ts index 3316eb4..9d4a8eb 100644 --- a/package/motion/__tests__/sentiment.test.ts +++ b/package/motion/__tests__/sentiment.test.ts @@ -1,24 +1,18 @@ import { LLM } from "../../utils/index.js"; import { TypeScriptChain } from "../index.js"; const llm = new LLM({}); -// const schema = ` +const schema = ` -// // The following is a schema definition for respond with the most concise words, each word carries rich meaning. The content of the response is a joking expression, reflecting the way some young people express their emotions and attitudes on the internet. -// export interface SentimentResponse { -// // On the Internet, when being criticized, don't run or be afraid. Remember the six-character mantra. Looking at the other side, as soon as they speak, use the word "classic" to counter. If you want to criticize, they want to praise, the word "filial" can piss them off. If they want to argue about truth and falsehood, saying "impatient" will make them collapse. If you don't understand their thoughts, if you can't "hold back", then "laugh". If they are right and you have nothing to say, typing "win" will make you "win". Don't be afraid of them, you have culture, only six words can break everything. Don't care about them, they have many moves, but rely on six words to conquer the world. -// sentiment: "classic" | "hold back" | "filial" | "happy" | "impatient" | "win"; // classic: refers to the classics, used to express approval or praise someone or something. filial: refers to filial piety, used to express respect or gratitude to parents. impatient: means impatient, used to express dissatisfaction or urge someone or something. happy: means happy, used to express joy or ridicule someone or something. hold back: means can't hold back the laughter, used to express the urge to laugh in response to someone or something. win: means win, used to express success or luck for oneself or others. -// }`; -// request: messageType | string; -// prompt?: messageType[]; -// schema?: string; -// typeName?: string; -// bound?: boolean; -// verbose?: boolean; -const chain = TypeScriptChain(llm); +// The following is a schema definition for respond with the most concise words, each word carries rich meaning. The content of the response is a joking expression, reflecting the way some young people express their emotions and attitudes on the internet. +export interface SentimentResponse { + // On the Internet, when being criticized, don't run or be afraid. Remember the six-character mantra. Looking at the other side, as soon as they speak, use the word "classic" to counter. If you want to criticize, they want to praise, the word "filial" can piss them off. If they want to argue about truth and falsehood, saying "impatient" will make them collapse. If you don't understand their thoughts, if you can't "hold back", then "laugh". If they are right and you have nothing to say, typing "win" will make you "win". Don't be afraid of them, you have culture, only six words can break everything. Don't care about them, they have many moves, but rely on six words to conquer the world. + sentiment: "classic" | "hold back" | "filial" | "happy" | "impatient" | "win"; // classic: refers to the classics, used to express approval or praise someone or something. filial: refers to filial piety, used to express respect or gratitude to parents. impatient: means impatient, used to express dissatisfaction or urge someone or something. happy: means happy, used to express joy or ridicule someone or something. hold back: means can't hold back the laughter, used to express the urge to laugh in response to someone or something. win: means win, used to express success or luck for oneself or others. +}`; +const chain = new TypeScriptChain(llm); const res = await chain.call({ request: "answer yes!", - // schema, - // typeName: "SentimentResponse", + schema, + typeName: "SentimentResponse", bound: false, verbose: true, }); diff --git a/package/motion/function.ts b/package/motion/function.ts index 465d6bc..0fe971d 100644 --- a/package/motion/function.ts +++ b/package/motion/function.ts @@ -1,4 +1,4 @@ -import { Result, success, Error, LLM } from "../utils/index.js"; +import { success, Error, LLM } from "../utils/index.js"; import { createMessage, messageType } from "../attention/index.js"; import { functionsType, function_callType } from "../utils/index.js"; export interface FunctionCallSchema { @@ -8,19 +8,14 @@ export interface FunctionCallSchema { function_call?: function_callType; verbose?: boolean; } -export interface FunctionSchema { - llm: LLM; - call(params: FunctionCallSchema): Promise>; -} -export function FunctionChain(llm: LLM): FunctionSchema { - const Function: FunctionSchema = { - llm, - call, - }; - return Function; +export class FunctionChain { + llm: LLM; + constructor(llm: LLM) { + this.llm = llm; + } - async function call(params: FunctionCallSchema): Promise> { + async call(params: FunctionCallSchema) { const { request, prompt, functions, function_call, verbose } = params; let messages: messageType[] = []; !!prompt && (messages = prompt); @@ -29,25 +24,25 @@ export function FunctionChain(llm: LLM): FunctionSchema { } else { messages.push(request); } - const res = await llm.chat({ + const res = await this.llm.chat({ modelName: "gpt-3.5-turbo", messages: messages, functions: functions || undefined, function_call: function_call || undefined, }); if (verbose) { - llm.printMessage(); + this.llm.printMessage(); } const responseText = res.choices[0].message.content; if (!responseText && !!res.choices[0].message.function_call) { const return_res = JSON.parse( res.choices[0].message.function_call.arguments as string, ); - return success(return_res as T); + return success(return_res); } if (!responseText) { return { success: false, message: responseText } as Error; } - return success(responseText as unknown as T); + return success(responseText); } } diff --git a/package/motion/typechat.ts b/package/motion/typechat.ts index a131997..30bb0d8 100644 --- a/package/motion/typechat.ts +++ b/package/motion/typechat.ts @@ -13,49 +13,6 @@ export interface TypeScriptChainCallSchema { bound?: boolean; verbose?: boolean; } -export interface TypeScriptChainSchema { - /** - * The associated `LLM`. - */ - llm: LLM; - /** - * A boolean indicating whether to attempt repairing JSON objects that fail to validate. The default is `true`, - * but an application can set the property to `false` to disable repair attempts. - */ - attemptRepair: boolean; - /** - * A boolean indicating whether to delete properties with null values from parsed JSON objects. Some language - * models (e.g. gpt-3.5-turbo) have a tendency to assign null values to optional properties instead of omitting - * them. The default for this property is `false`, but an application can set the property to `true` for schemas - * that don't permit null values. - */ - stripNulls: boolean; - /** - * Creates an AI language model prompt from the given request. This function is called by `completeAndValidate` - * to obtain the prompt. An application can assign a new function to provide a different prompt. - * @param request The natural language request. - * @returns A prompt that combines the request with the schema and type name of the underlying validator. - */ - createRequestPrompt(validator: TypeChatJsonValidator): messageType; - /** - * Creates a repair prompt to append to an original prompt/response in order to repair a JSON object that - * failed to validate. This function is called by `completeAndValidate` when `attemptRepair` is true and the - * JSON object produced by the original prompt failed to validate. An application can assign a new function - * to provide a different repair prompt. - * @param validationError The error message returned by the validator. - * @returns A repair prompt constructed from the error message. - */ - createRepairPrompt(validationError: string): messageType; - /** - * Translates a natural language request into an object of type `T`. If the JSON object returned by - * the language model fails to validate and the `attemptRepair` property is `true`, a second - * attempt to translate the request will be made. The prompt for the second attempt will include the - * diagnostics produced for the first attempt. This often helps produce a valid instance. - * @param prompt The natural language request. - * @returns A promise for the resulting object. - */ - call(params: TypeScriptChainCallSchema): Promise>; -} /** * Creates an object that can translate natural language requests into JSON objects of the given type. @@ -66,22 +23,16 @@ export interface TypeScriptChainSchema { * @param typeName The name of the JSON target type in the schema. * @returns A `TypeChatJsonTranslator` instance. */ -export function TypeScriptChain( - llm: LLM, -): TypeScriptChainSchema { - const typeChat: TypeScriptChainSchema = { - llm, - attemptRepair: true, - stripNulls: false, - createRequestPrompt, - createRepairPrompt, - call, - }; - return typeChat; +export class TypeScriptChain { + llm: LLM; + attemptRepair: boolean; + + constructor(llm: LLM) { + this.llm = llm; + this.attemptRepair = true; + } - function createRequestPrompt( - validator: TypeChatJsonValidator, - ): messageType { + createRequestPrompt(validator: TypeChatJsonValidator): messageType { return createMessage( "system", `\nYou need to process user requests and then translates result into JSON objects of type "${validator.typeName}" according to the following TypeScript definitions:\n` + @@ -91,7 +42,7 @@ export function TypeScriptChain( ); } - function createRepairPrompt(validationError: string) { + createRepairPrompt(validationError: string): messageType { return createMessage( "system", `The JSON object is invalid for the following reason:\n` + @@ -101,14 +52,14 @@ export function TypeScriptChain( ); } - async function call(params: TypeScriptChainCallSchema): Promise> { + async call(params: TypeScriptChainCallSchema): Promise> { const { request, prompt, schema, typeName, bound, verbose } = params; - let validator: TypeChatJsonValidator | undefined = undefined, + let validator: TypeChatJsonValidator | undefined = undefined, resPrompt: messageType[] = [], request_: messageType | string = request; !!schema && !!typeName && - (validator = createJsonValidator(schema, typeName)); + (validator = createJsonValidator(schema, typeName)); !!prompt && (resPrompt = prompt); if (bound) { @@ -120,39 +71,39 @@ export function TypeScriptChain( // 如果是字符串,转换成消息对象 if (typeof request_ === "string") { resPrompt.push(createMessage("user", request_)); - !!validator && resPrompt.push(typeChat.createRequestPrompt(validator)); + !!validator && resPrompt.push(this.createRequestPrompt(validator)); } else { resPrompt.push(request_); - !!validator && resPrompt.push(typeChat.createRequestPrompt(validator)); + !!validator && resPrompt.push(this.createRequestPrompt(validator)); } } - let attemptRepair = typeChat.attemptRepair; + let attemptRepair = this.attemptRepair; while (true) { - let response = await llm.chat({ + let response = await this.llm.chat({ messages: resPrompt, }); let responseText = response.choices[0].message.content; // responseText = '{ "sentiment": "play"}'; if (!responseText) { if (verbose) { - llm.printMessage(); + this.llm.printMessage(); } return { success: false, message: responseText } as Error; } if (!validator || !schema || !typeName) { if (verbose) { - llm.printMessage(); + this.llm.printMessage(); } - return { success: true, data: responseText } as unknown as Result; + return { success: true, data: responseText } as unknown as Result; } if (bound) { resPrompt.push(createMessage("assistant", responseText)); - !!validator && resPrompt.push(typeChat.createRequestPrompt(validator)); - response = await llm.chat({ messages: resPrompt }); + !!validator && resPrompt.push(this.createRequestPrompt(validator)); + response = await this.llm.chat({ messages: resPrompt }); responseText = response.choices[0].message.content; if (!responseText) { if (verbose) { - llm.printMessage(); + this.llm.printMessage(); } return { success: false, message: responseText } as Error; } @@ -162,7 +113,7 @@ export function TypeScriptChain( const endIndex = responseText.lastIndexOf("}"); if (!(startIndex >= 0 && endIndex > startIndex)) { if (verbose) { - llm.printMessage(); + this.llm.printMessage(); } return error(`Response is not JSON:\n${responseText}`); } @@ -170,13 +121,13 @@ export function TypeScriptChain( const validation = validator.validate(jsonText); if (validation.success) { if (verbose) { - llm.printMessage(); + this.llm.printMessage(); } return validation; } if (!attemptRepair) { if (verbose) { - llm.printMessage(); + this.llm.printMessage(); } return error( `JSON validation failed: ${validation.message}\n${jsonText}`, @@ -184,7 +135,7 @@ export function TypeScriptChain( } // resPrompt.push(createMessage("user", responseText)); resPrompt = []; // 继续对话 - resPrompt.push(typeChat.createRepairPrompt(validation.message)); + resPrompt.push(this.createRepairPrompt(validation.message)); attemptRepair = false; } }