Skip to content

Commit

Permalink
Replace legacy Model.Message type with Msg
Browse files Browse the repository at this point in the history
  • Loading branch information
rileytomasek committed Oct 17, 2024
1 parent 7947e4b commit d419fd0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 17 deletions.
10 changes: 5 additions & 5 deletions src/model/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { type PartialDeep, type SetOptional } from 'type-fest';

import { createOpenAIClient } from './clients/openai.js';
import { AbstractModel, type ModelArgs } from './model.js';
import { type Model } from './types.js';
import { type Model, type Msg } from './types.js';
import { calculateCost } from './utils/calculate-cost.js';
import { deepMerge, mergeEvents, type Prettify } from './utils/helpers.js';
import { MsgUtil } from './utils/message-util.js';
Expand Down Expand Up @@ -276,7 +276,7 @@ export class ChatModel<
/**
* Verbose logging for debugging prompts
*/
function logInput(args: { params: { messages: Model.Message[] } }) {
function logInput(args: { params: { messages: Msg[] } }) {
console.debug(`-----> [Request] ----->`);
console.debug();
args.params.messages.forEach(logMessage);
Expand All @@ -290,10 +290,10 @@ function logResponse(args: {
};
cached: boolean;
latency?: number;
choices: { message: Model.Message }[];
choices: { message: Msg }[];
cost?: number;
};
params: { messages: Model.Message[] };
params: { messages: Msg[] };
}) {
const { usage, cost, latency, choices } = args.response;
const tokens = {
Expand All @@ -314,7 +314,7 @@ function logResponse(args: {
logMessage(message, args.params.messages.length + 1);
}

function logMessage(message: Model.Message, index: number) {
function logMessage(message: Msg, index: number) {
console.debug(
`[${index}] ${message.role.toUpperCase()}:${
'name' in message ? ` (${message.name}) ` : ''
Expand Down
10 changes: 3 additions & 7 deletions src/model/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ export namespace Model {
streamChatCompletion: OpenAIClient['streamChatCompletion'];
};
export interface Run extends Base.Run {
messages: Model.Message[];
messages: Msg[];
}
export interface Config extends Base.Config {
/** Handle new chunk from streaming requests. */
Expand Down Expand Up @@ -291,8 +291,8 @@ export namespace Model {
/** Decode an array of integer tokens into a string */
decode(tokens: number[] | Uint32Array): string;
/**
* Count the number of tokens in a string or ChatMessage(s).
* A single ChatMessage is counted as a completion and an array as a prompt.
* Count the number of tokens in a string or message(s).
* A single Msg is counted as a completion and an array as a prompt.
* Strings are counted as is.
*/
countTokens(input?: string | Msg | Msg[]): number;
Expand All @@ -307,10 +307,6 @@ export namespace Model {
}): string;
}

// TODO: replace with just Msg
/** Primary message type for chat models */
export type Message = Msg;

/** The provider of the model (eg: OpenAI) */
export type Provider = (string & {}) | 'openai' | 'custom';

Expand Down
11 changes: 6 additions & 5 deletions src/model/utils/tokenizer.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import { type ChatMessage } from 'openai-fetch';
import {
encoding_for_model,
type Tiktoken,
type TiktokenModel,
} from 'tiktoken';

import { type Model } from '../types.js';
import { type Model, type Msg } from '../types.js';

const GPT_4_MODELS = [
'gpt-4',
Expand Down Expand Up @@ -56,7 +55,7 @@ class Tokenizer implements Model.ITokenizer {
 * Count the number of tokens in a string or Msg(s)
 * A single message is counted as a completion and an array as a prompt
**/
countTokens(input?: string | ChatMessage | ChatMessage[]): number {
countTokens(input?: string | Msg | Msg[]): number {
if (!input) return 0;
if (typeof input === 'string') {
return this.tiktoken.encode(input).length;
Expand All @@ -79,13 +78,15 @@ class Tokenizer implements Model.ITokenizer {
// For 4, the name and role are included
// Details here: https://github.com/openai/openai-python/blob/main/chatml.md
numTokens += 1; // role
if (message.name) {
if ('name' in message) {
// No idea why this, but tested with many examples and it works...
numTokens += this.countTokens(`${message.name}`) + 1;
}
} else {
// For 3.5, the name replaces the role if it's present
numTokens += this.countTokens(message.name || message.role);
numTokens += this.countTokens(
'name' in message ? message.name : message.role
);
}
}

Expand Down

0 comments on commit d419fd0

Please sign in to comment.