Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added AIM support for Meta Llama3 models in AWS Bedrock #2306

Merged
merged 3 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion ai-support.json
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,19 @@
}
]
},

{
"name": "Meta Llama3",
"features": [
{
"title": "Text",
"supported": true
},
{
"title": "Image",
"supported": false
}
]
},
{
"name": "Amazon Titan",
"features": [
Expand Down
10 changes: 5 additions & 5 deletions lib/llm-events/aws-bedrock/bedrock-command.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class BedrockCommand {
result = this.#body.max_tokens_to_sample
} else if (this.isClaude3() === true || this.isCohere() === true) {
result = this.#body.max_tokens
} else if (this.isLlama2() === true) {
} else if (this.isLlama() === true) {
jsumners-nr marked this conversation as resolved.
Show resolved Hide resolved
result = this.#body.max_gen_length
} else if (this.isTitan() === true) {
result = this.#body.textGenerationConfig?.maxTokenCount
Expand Down Expand Up @@ -80,7 +80,7 @@ class BedrockCommand {
this.isClaude() === true ||
this.isAi21() === true ||
this.isCohere() === true ||
this.isLlama2() === true
this.isLlama() === true
) {
result = this.#body.prompt
} else if (this.isClaude3() === true) {
Expand All @@ -104,7 +104,7 @@ class BedrockCommand {
this.isClaude3() === true ||
this.isAi21() === true ||
this.isCohere() === true ||
this.isLlama2() === true
this.isLlama() === true
) {
result = this.#body.temperature
}
Expand All @@ -131,8 +131,8 @@ class BedrockCommand {
return this.#modelId.startsWith('cohere.embed')
}

isLlama2() {
return this.#modelId.startsWith('meta.llama2')
isLlama() {
return this.#modelId.startsWith('meta.llama')
}

isTitan() {
Expand Down
4 changes: 2 additions & 2 deletions lib/llm-events/aws-bedrock/bedrock-response.js
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class BedrockResponse {
} else if (cmd.isCohere() === true) {
this.#completions = body.generations?.map((g) => g.text) ?? []
this.#id = body.id
} else if (cmd.isLlama2() === true) {
} else if (cmd.isLlama() === true) {
body.generation && this.#completions.push(body.generation)
} else if (cmd.isTitan() === true) {
this.#completions = body.results?.map((r) => r.outputText) ?? []
Expand Down Expand Up @@ -107,7 +107,7 @@ class BedrockResponse {
result = this.#parsedBody.stop_reason
} else if (cmd.isCohere() === true) {
result = this.#parsedBody.generations?.find((r) => r.finish_reason !== null)?.finish_reason
} else if (cmd.isLlama2() === true) {
} else if (cmd.isLlama() === true) {
result = this.#parsedBody.stop_reason
} else if (cmd.isTitan() === true) {
result = this.#parsedBody.results?.find((r) => r.completionReason !== null)?.completionReason
Expand Down
6 changes: 3 additions & 3 deletions lib/llm-events/aws-bedrock/stream-handler.js
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ class StreamHandler {
} else if (bedrockCommand.isCohereEmbed() === true) {
this.stopReasonKey = 'nr_none'
this.generator = handleCohereEmbed
} else if (bedrockCommand.isLlama2() === true) {
} else if (bedrockCommand.isLlama() === true) {
this.stopReasonKey = 'stop_reason'
this.generator = handleLlama2
this.generator = handleLlama
} else if (bedrockCommand.isTitan() === true) {
this.stopReasonKey = 'completionReason'
this.generator = handleTitan
Expand Down Expand Up @@ -271,7 +271,7 @@ async function* handleCohereEmbed() {
}
}

async function* handleLlama2() {
async function* handleLlama() {
let currentBody = {}
let generation = ''

Expand Down
7 changes: 5 additions & 2 deletions test/lib/aws-server-stubs/ai-server/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,11 @@ function handler(req, res) {
}

case 'meta.llama2-13b-chat-v1':
case 'meta.llama2-70b-chat-v1': {
response = responses.llama2.get(payload.prompt)
case 'meta.llama2-70b-chat-v1':
// llama3 responses are identical, just return llama2 data
case 'meta.llama3-8b-instruct-v1:0':
case 'meta.llama3-70b-instruct-v1:0': {
response = responses.llama.get(payload.prompt)
break
}

Expand Down
4 changes: 2 additions & 2 deletions test/lib/aws-server-stubs/ai-server/responses/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ const amazon = require('./amazon')
const claude = require('./claude')
const claude3 = require('./claude3')
const cohere = require('./cohere')
const llama2 = require('./llama2')
const llama = require('./llama')

module.exports = {
ai21,
amazon,
claude,
claude3,
cohere,
llama2
llama
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
const responses = new Map()
const { contentType, reqId } = require('./constants')

responses.set('text llama2 ultimate question', {
responses.set('text llama ultimate question', {
headers: {
'content-type': contentType,
'x-amzn-requestid': reqId,
Expand All @@ -25,7 +25,7 @@ responses.set('text llama2 ultimate question', {
}
})

responses.set('text llama2 ultimate question streamed', {
responses.set('text llama ultimate question streamed', {
headers: {
'content-type': 'application/vnd.amazon.eventstream',
'x-amzn-requestid': reqId,
Expand Down Expand Up @@ -68,7 +68,7 @@ responses.set('text llama2 ultimate question streamed', {
]
})

responses.set('text llama2 ultimate question error', {
responses.set('text llama ultimate question error', {
headers: {
'content-type': contentType,
'x-amzn-requestid': reqId,
Expand Down
38 changes: 35 additions & 3 deletions test/unit/llm-events/aws-bedrock/bedrock-command.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ const llama2 = {
}
}

const llama3 = {
modelId: 'meta.llama3-8b-instruct-v1:0',
body: {
prompt: 'who are you'
}
}

const titan = {
modelId: 'amazon.titan-text-lite-v1',
body: {
Expand Down Expand Up @@ -85,7 +92,7 @@ tap.test('non-conforming command is handled gracefully', async (t) => {
'Claude3',
'Cohere',
'CohereEmbed',
'Llama2',
'Llama',
'Titan',
'TitanEmbed'
]) {
Expand Down Expand Up @@ -212,7 +219,7 @@ tap.test('cohere embed minimal command works', async (t) => {
tap.test('llama2 minimal command works', async (t) => {
t.context.updatePayload(structuredClone(llama2))
const cmd = new BedrockCommand(t.context.input)
t.equal(cmd.isLlama2(), true)
t.equal(cmd.isLlama(), true)
t.equal(cmd.maxTokens, undefined)
t.equal(cmd.modelId, llama2.modelId)
t.equal(cmd.modelType, 'completion')
Expand All @@ -226,7 +233,32 @@ tap.test('llama2 complete command works', async (t) => {
payload.body.temperature = 0.5
t.context.updatePayload(payload)
const cmd = new BedrockCommand(t.context.input)
t.equal(cmd.isLlama2(), true)
t.equal(cmd.isLlama(), true)
t.equal(cmd.maxTokens, 25)
t.equal(cmd.modelId, payload.modelId)
t.equal(cmd.modelType, 'completion')
t.equal(cmd.prompt, payload.body.prompt)
t.equal(cmd.temperature, payload.body.temperature)
})

tap.test('llama3 minimal command works', async (t) => {
t.context.updatePayload(structuredClone(llama3))
const cmd = new BedrockCommand(t.context.input)
t.equal(cmd.isLlama(), true)
t.equal(cmd.maxTokens, undefined)
t.equal(cmd.modelId, llama3.modelId)
t.equal(cmd.modelType, 'completion')
t.equal(cmd.prompt, llama3.body.prompt)
t.equal(cmd.temperature, undefined)
})

tap.test('llama3 complete command works', async (t) => {
const payload = structuredClone(llama3)
payload.body.max_gen_length = 25
payload.body.temperature = 0.5
t.context.updatePayload(payload)
const cmd = new BedrockCommand(t.context.input)
t.equal(cmd.isLlama(), true)
t.equal(cmd.maxTokens, 25)
t.equal(cmd.modelId, payload.modelId)
t.equal(cmd.modelType, 'completion')
Expand Down
18 changes: 9 additions & 9 deletions test/unit/llm-events/aws-bedrock/bedrock-response.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ const cohere = {
]
}

const llama2 = {
generation: 'llama2-response',
const llama = {
generation: 'llama-response',
stop_reason: 'done'
}

Expand Down Expand Up @@ -79,7 +79,7 @@ tap.beforeEach((t) => {
isCohere() {
return false
},
isLlama2() {
isLlama() {
return false
},
isTitan() {
Expand Down Expand Up @@ -172,8 +172,8 @@ tap.test('cohere complete responses work', async (t) => {
t.equal(res.statusCode, 200)
})

tap.test('llama2 malformed responses work', async (t) => {
t.context.bedrockCommand.isLlama2 = () => true
tap.test('llama malformed responses work', async (t) => {
t.context.bedrockCommand.isLlama = () => true
const res = new BedrockResponse(t.context)
t.same(res.completions, [])
t.equal(res.finishReason, undefined)
Expand All @@ -183,11 +183,11 @@ tap.test('llama2 malformed responses work', async (t) => {
t.equal(res.statusCode, 200)
})

tap.test('llama2 complete responses work', async (t) => {
t.context.bedrockCommand.isLlama2 = () => true
t.context.updatePayload(structuredClone(llama2))
tap.test('llama complete responses work', async (t) => {
t.context.bedrockCommand.isLlama = () => true
t.context.updatePayload(structuredClone(llama))
const res = new BedrockResponse(t.context)
t.same(res.completions, ['llama2-response'])
t.same(res.completions, ['llama-response'])
t.equal(res.finishReason, 'done')
t.same(res.headers, t.context.response.response.headers)
t.equal(res.id, undefined)
Expand Down
10 changes: 5 additions & 5 deletions test/unit/llm-events/aws-bedrock/stream-handler.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ tap.beforeEach((t) => {
isClaude3() {
return false
},
isLlama2() {
isLlama() {
return false
},
isTitan() {
Expand Down Expand Up @@ -242,15 +242,15 @@ tap.test('handles cohere embedding streams', async (t) => {
t.equal(br.statusCode, 200)
})

tap.test('handles llama2 streams', async (t) => {
t.context.passThroughParams.bedrockCommand.isLlama2 = () => true
tap.test('handles llama streams', async (t) => {
t.context.passThroughParams.bedrockCommand.isLlama = () => true
t.context.chunks = [
{ generation: '1', stop_reason: null },
{ generation: '2', stop_reason: 'done', ...t.context.metrics }
]
const handler = new StreamHandler(t.context)

t.equal(handler.generator.name, 'handleLlama2')
t.equal(handler.generator.name, 'handleLlama')
for await (const event of handler.generator()) {
t.type(event.chunk.bytes, Uint8Array)
}
Expand All @@ -267,7 +267,7 @@ tap.test('handles llama2 streams', async (t) => {
})

const bc = new BedrockCommand({
modelId: 'meta.llama2',
modelId: 'meta.llama',
body: JSON.stringify({
prompt: 'prompt',
max_gen_length: 5
Expand Down
5 changes: 3 additions & 2 deletions test/versioned/aws-sdk-v3/bedrock-chat-completions.tap.js
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ const requests = {
body: JSON.stringify({ prompt, temperature: 0.5, max_tokens: 100 }),
modelId
}),
llama2: (prompt, modelId) => ({
llama: (prompt, modelId) => ({
body: JSON.stringify({ prompt, max_gen_length: 100, temperature: 0.5 }),
modelId
})
Expand Down Expand Up @@ -98,7 +98,8 @@ tap.afterEach(async (t) => {
{ modelId: 'anthropic.claude-v2', resKey: 'claude' },
{ modelId: 'anthropic.claude-3-haiku-20240307-v1:0', resKey: 'claude3' },
{ modelId: 'cohere.command-text-v14', resKey: 'cohere' },
{ modelId: 'meta.llama2-13b-chat-v1', resKey: 'llama2' }
{ modelId: 'meta.llama2-13b-chat-v1', resKey: 'llama' },
{ modelId: 'meta.llama3-8b-instruct-v1:0', resKey: 'llama' }
].forEach(({ modelId, resKey }) => {
tap.test(`${modelId}: should properly create completion segment`, (t) => {
const { bedrock, client, responses, agent, expectedExternalPath } = t.context
Expand Down
Loading