feat: Added AIM support for Meta Llama3 models in AWS Bedrock (#2306)
amychisholm03 authored Jun 27, 2024
1 parent 0bf8908 commit ff2e509
Showing 11 changed files with 85 additions and 37 deletions.
14 changes: 13 additions & 1 deletion ai-support.json
@@ -64,7 +64,19 @@
        }
      ]
    },
+   {
+     "name": "Meta Llama3",
+     "features": [
+       {
+         "title": "Text",
+         "supported": true
+       },
+       {
+         "title": "Image",
+         "supported": false
+       }
+     ]
+   },
    {
      "name": "Amazon Titan",
      "features": [
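The new ai-support.json entry declares the documented capability matrix for the model family: text completions are instrumented, image inputs are not. A minimal sketch of querying such an entry; the `supportsFeature` helper is hypothetical and `models` stands in for wherever the entries live inside ai-support.json:

```js
// Sketch: querying a { name, features } capability entry like the one added above.
const models = [
  {
    name: 'Meta Llama3',
    features: [
      { title: 'Text', supported: true },
      { title: 'Image', supported: false }
    ]
  }
]

function supportsFeature(modelName, featureTitle) {
  const model = models.find((m) => m.name === modelName)
  return model?.features.some((f) => f.title === featureTitle && f.supported) ?? false
}

supportsFeature('Meta Llama3', 'Text')  // true
supportsFeature('Meta Llama3', 'Image') // false
```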
10 changes: 5 additions & 5 deletions lib/llm-events/aws-bedrock/bedrock-command.js
@@ -37,7 +37,7 @@ class BedrockCommand {
      result = this.#body.max_tokens_to_sample
    } else if (this.isClaude3() === true || this.isCohere() === true) {
      result = this.#body.max_tokens
-   } else if (this.isLlama2() === true) {
+   } else if (this.isLlama() === true) {
      result = this.#body.max_gen_length
    } else if (this.isTitan() === true) {
      result = this.#body.textGenerationConfig?.maxTokenCount
@@ -80,7 +80,7 @@ class BedrockCommand {
      this.isClaude() === true ||
      this.isAi21() === true ||
      this.isCohere() === true ||
-     this.isLlama2() === true
+     this.isLlama() === true
    ) {
      result = this.#body.prompt
    } else if (this.isClaude3() === true) {
@@ -104,7 +104,7 @@
      this.isClaude3() === true ||
      this.isAi21() === true ||
      this.isCohere() === true ||
-     this.isLlama2() === true
+     this.isLlama() === true
    ) {
      result = this.#body.temperature
    }
@@ -131,8 +131,8 @@
    return this.#modelId.startsWith('cohere.embed')
  }

- isLlama2() {
-   return this.#modelId.startsWith('meta.llama2')
+ isLlama() {
+   return this.#modelId.startsWith('meta.llama')
  }

  isTitan() {
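The core of the change is broadening the model-id prefix check from `meta.llama2` to `meta.llama`, which matches both generations, since all Meta model ids on Bedrock begin with `meta.llama`. A quick standalone sketch of the predicate against the ids this commit targets:

```js
// Sketch: the broadened prefix check accepts Llama 2 and Llama 3 model ids alike.
const isLlama = (modelId) => modelId.startsWith('meta.llama')

isLlama('meta.llama2-13b-chat-v1')      // true
isLlama('meta.llama3-8b-instruct-v1:0') // true (new in this commit)
isLlama('amazon.titan-text-lite-v1')    // false
```

A bare prefix check will also match any future `meta.llama*` id, which is exactly the behavior this commit wants, because Llama 2 and Llama 3 share a request/response shape.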
4 changes: 2 additions & 2 deletions lib/llm-events/aws-bedrock/bedrock-response.js
@@ -70,7 +70,7 @@ class BedrockResponse {
    } else if (cmd.isCohere() === true) {
      this.#completions = body.generations?.map((g) => g.text) ?? []
      this.#id = body.id
-   } else if (cmd.isLlama2() === true) {
+   } else if (cmd.isLlama() === true) {
      body.generation && this.#completions.push(body.generation)
    } else if (cmd.isTitan() === true) {
      this.#completions = body.results?.map((r) => r.outputText) ?? []
@@ -107,7 +107,7 @@
      result = this.#parsedBody.stop_reason
    } else if (cmd.isCohere() === true) {
      result = this.#parsedBody.generations?.find((r) => r.finish_reason !== null)?.finish_reason
-   } else if (cmd.isLlama2() === true) {
+   } else if (cmd.isLlama() === true) {
      result = this.#parsedBody.stop_reason
    } else if (cmd.isTitan() === true) {
      result = this.#parsedBody.results?.find((r) => r.completionReason !== null)?.completionReason
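A single `isLlama()` branch suffices here because Llama 2 and Llama 3 return the same non-streaming body shape: a `generation` string plus a `stop_reason`. A simplified sketch of the extraction, outside the class machinery (the sample body is illustrative):

```js
// Sketch: pull the completion and finish reason out of a Llama 2/3 response body.
const body = JSON.parse('{"generation": "42", "stop_reason": "stop"}')

const completions = []
body.generation && completions.push(body.generation) // same truthiness guard as the diff
const finishReason = body.stop_reason

console.log(completions, finishReason) // [ '42' ] 'stop'
```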
6 changes: 3 additions & 3 deletions lib/llm-events/aws-bedrock/stream-handler.js
@@ -114,9 +114,9 @@ class StreamHandler {
    } else if (bedrockCommand.isCohereEmbed() === true) {
      this.stopReasonKey = 'nr_none'
      this.generator = handleCohereEmbed
-   } else if (bedrockCommand.isLlama2() === true) {
+   } else if (bedrockCommand.isLlama() === true) {
      this.stopReasonKey = 'stop_reason'
-     this.generator = handleLlama2
+     this.generator = handleLlama
    } else if (bedrockCommand.isTitan() === true) {
      this.stopReasonKey = 'completionReason'
      this.generator = handleTitan
@@ -271,7 +271,7 @@ async function* handleCohereEmbed() {
  }
}

-async function* handleLlama2() {
+async function* handleLlama() {
  let currentBody = {}
  let generation = ''

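The body of `handleLlama` is elided by the diff, but the declarations it keeps (`currentBody`, `generation`) show the pattern: accumulate the `generation` fragment of each streamed chunk into one completion while passing events through. A simplified sketch of that accumulation, not the verbatim implementation:

```js
// Sketch: fold streamed Llama chunks into a single completion string.
async function* handleLlamaSketch(chunks) {
  let currentBody = {}
  let generation = ''
  for await (const chunk of chunks) {
    currentBody = chunk                  // the last chunk carries stop_reason and metrics
    generation += chunk.generation ?? '' // each chunk carries a fragment of the text
    yield chunk                          // events still flow to the caller untouched
  }
  currentBody.generation = generation    // final body holds the full completion
  return currentBody
}
```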
7 changes: 5 additions & 2 deletions test/lib/aws-server-stubs/ai-server/index.js
@@ -114,8 +114,11 @@ function handler(req, res) {
    }

    case 'meta.llama2-13b-chat-v1':
-   case 'meta.llama2-70b-chat-v1': {
-     response = responses.llama2.get(payload.prompt)
+   case 'meta.llama2-70b-chat-v1':
+   // llama3 responses are identical, just return llama2 data
+   case 'meta.llama3-8b-instruct-v1:0':
+   case 'meta.llama3-70b-instruct-v1:0': {
+     response = responses.llama.get(payload.prompt)
      break
    }

4 changes: 2 additions & 2 deletions test/lib/aws-server-stubs/ai-server/responses/index.js
@@ -10,13 +10,13 @@ const amazon = require('./amazon')
const claude = require('./claude')
const claude3 = require('./claude3')
const cohere = require('./cohere')
-const llama2 = require('./llama2')
+const llama = require('./llama')

module.exports = {
  ai21,
  amazon,
  claude,
  claude3,
  cohere,
-  llama2
+  llama
}
6 changes: 3 additions & 3 deletions test/lib/aws-server-stubs/ai-server/responses/{llama2.js → llama.js}
@@ -8,7 +8,7 @@
const responses = new Map()
const { contentType, reqId } = require('./constants')

-responses.set('text llama2 ultimate question', {
+responses.set('text llama ultimate question', {
  headers: {
    'content-type': contentType,
    'x-amzn-requestid': reqId,
@@ -25,7 +25,7 @@ responses.set('text llama2 ultimate question', {
  }
})

-responses.set('text llama2 ultimate question streamed', {
+responses.set('text llama ultimate question streamed', {
  headers: {
    'content-type': 'application/vnd.amazon.eventstream',
    'x-amzn-requestid': reqId,
@@ -68,7 +68,7 @@ responses.set('text llama2 ultimate question streamed', {
  ]
})

-responses.set('text llama2 ultimate question error', {
+responses.set('text llama ultimate question error', {
  headers: {
    'content-type': contentType,
    'x-amzn-requestid': reqId,
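These canned responses are keyed on the literal prompt text, so a test selects a scenario (success, streamed, error) just by choosing its prompt. A hypothetical consumer, mirroring what the stub server's handler does with the renamed map:

```js
// Sketch: the stub handler looks the incoming prompt up in responses.llama.
const responses = require('./responses')

const ok = responses.llama.get('text llama ultimate question')
const streamed = responses.llama.get('text llama ultimate question streamed')
const error = responses.llama.get('text llama ultimate question error')
```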
38 changes: 35 additions & 3 deletions test/unit/llm-events/aws-bedrock/bedrock-command.test.js
@@ -52,6 +52,13 @@ const llama2 = {
  }
}

+const llama3 = {
+  modelId: 'meta.llama3-8b-instruct-v1:0',
+  body: {
+    prompt: 'who are you'
+  }
+}
+
const titan = {
  modelId: 'amazon.titan-text-lite-v1',
  body: {
@@ -85,7 +92,7 @@ tap.test('non-conforming command is handled gracefully', async (t) => {
    'Claude3',
    'Cohere',
    'CohereEmbed',
-   'Llama2',
+   'Llama',
    'Titan',
    'TitanEmbed'
  ]) {
@@ -212,7 +219,7 @@ tap.test('cohere embed minimal command works', async (t) => {
tap.test('llama2 minimal command works', async (t) => {
  t.context.updatePayload(structuredClone(llama2))
  const cmd = new BedrockCommand(t.context.input)
- t.equal(cmd.isLlama2(), true)
+ t.equal(cmd.isLlama(), true)
  t.equal(cmd.maxTokens, undefined)
  t.equal(cmd.modelId, llama2.modelId)
  t.equal(cmd.modelType, 'completion')
@@ -226,7 +233,32 @@ tap.test('llama2 complete command works', async (t) => {
  payload.body.temperature = 0.5
  t.context.updatePayload(payload)
  const cmd = new BedrockCommand(t.context.input)
- t.equal(cmd.isLlama2(), true)
+ t.equal(cmd.isLlama(), true)
  t.equal(cmd.maxTokens, 25)
  t.equal(cmd.modelId, payload.modelId)
  t.equal(cmd.modelType, 'completion')
  t.equal(cmd.prompt, payload.body.prompt)
  t.equal(cmd.temperature, payload.body.temperature)
})

+tap.test('llama3 minimal command works', async (t) => {
+  t.context.updatePayload(structuredClone(llama3))
+  const cmd = new BedrockCommand(t.context.input)
+  t.equal(cmd.isLlama(), true)
+  t.equal(cmd.maxTokens, undefined)
+  t.equal(cmd.modelId, llama3.modelId)
+  t.equal(cmd.modelType, 'completion')
+  t.equal(cmd.prompt, llama3.body.prompt)
+  t.equal(cmd.temperature, undefined)
+})
+
+tap.test('llama3 complete command works', async (t) => {
+  const payload = structuredClone(llama3)
+  payload.body.max_gen_length = 25
+  payload.body.temperature = 0.5
+  t.context.updatePayload(payload)
+  const cmd = new BedrockCommand(t.context.input)
+  t.equal(cmd.isLlama(), true)
+  t.equal(cmd.maxTokens, 25)
+  t.equal(cmd.modelId, payload.modelId)
+  t.equal(cmd.modelType, 'completion')
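For context, the payloads these tests clone are the same shape an application would send through the AWS SDK. A hedged usage sketch with @aws-sdk/client-bedrock-runtime (region and prompt are placeholders; the body fields mirror the test payloads above, and the example is not part of this commit):

```js
// Sketch: invoking a Llama 3 model on Bedrock; the agent instruments client.send().
const { BedrockRuntimeClient, InvokeModelCommand } = require('@aws-sdk/client-bedrock-runtime')

async function main() {
  const client = new BedrockRuntimeClient({ region: 'us-east-1' })
  const response = await client.send(
    new InvokeModelCommand({
      modelId: 'meta.llama3-8b-instruct-v1:0',
      contentType: 'application/json',
      accept: 'application/json',
      body: JSON.stringify({ prompt: 'who are you', temperature: 0.5 })
    })
  )
  // Non-streaming responses arrive as bytes holding the JSON body parsed above.
  const parsed = JSON.parse(new TextDecoder().decode(response.body))
  console.log(parsed.generation, parsed.stop_reason)
}

main()
```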
18 changes: 9 additions & 9 deletions test/unit/llm-events/aws-bedrock/bedrock-response.test.js
@@ -38,8 +38,8 @@ const cohere = {
  ]
}

-const llama2 = {
-  generation: 'llama2-response',
+const llama = {
+  generation: 'llama-response',
  stop_reason: 'done'
}

@@ -79,7 +79,7 @@ tap.beforeEach((t) => {
    isCohere() {
      return false
    },
-   isLlama2() {
+   isLlama() {
      return false
    },
    isTitan() {
@@ -172,8 +172,8 @@ tap.test('cohere complete responses work', async (t) => {
  t.equal(res.statusCode, 200)
})

-tap.test('llama2 malformed responses work', async (t) => {
-  t.context.bedrockCommand.isLlama2 = () => true
+tap.test('llama malformed responses work', async (t) => {
+  t.context.bedrockCommand.isLlama = () => true
  const res = new BedrockResponse(t.context)
  t.same(res.completions, [])
  t.equal(res.finishReason, undefined)
@@ -183,11 +183,11 @@ tap.test('llama2 malformed responses work', async (t) => {
  t.equal(res.statusCode, 200)
})

-tap.test('llama2 complete responses work', async (t) => {
-  t.context.bedrockCommand.isLlama2 = () => true
-  t.context.updatePayload(structuredClone(llama2))
+tap.test('llama complete responses work', async (t) => {
+  t.context.bedrockCommand.isLlama = () => true
+  t.context.updatePayload(structuredClone(llama))
  const res = new BedrockResponse(t.context)
- t.same(res.completions, ['llama2-response'])
+ t.same(res.completions, ['llama-response'])
  t.equal(res.finishReason, 'done')
  t.same(res.headers, t.context.response.response.headers)
  t.equal(res.id, undefined)
10 changes: 5 additions & 5 deletions test/unit/llm-events/aws-bedrock/stream-handler.test.js
@@ -45,7 +45,7 @@ tap.beforeEach((t) => {
    isClaude3() {
      return false
    },
-   isLlama2() {
+   isLlama() {
      return false
    },
    isTitan() {
@@ -242,15 +242,15 @@ tap.test('handles cohere embedding streams', async (t) => {
  t.equal(br.statusCode, 200)
})

-tap.test('handles llama2 streams', async (t) => {
-  t.context.passThroughParams.bedrockCommand.isLlama2 = () => true
+tap.test('handles llama streams', async (t) => {
+  t.context.passThroughParams.bedrockCommand.isLlama = () => true
  t.context.chunks = [
    { generation: '1', stop_reason: null },
    { generation: '2', stop_reason: 'done', ...t.context.metrics }
  ]
  const handler = new StreamHandler(t.context)

- t.equal(handler.generator.name, 'handleLlama2')
+ t.equal(handler.generator.name, 'handleLlama')
  for await (const event of handler.generator()) {
    t.type(event.chunk.bytes, Uint8Array)
  }
@@ -267,7 +267,7 @@ tap.test('handles llama2 streams', async (t) => {
})

const bc = new BedrockCommand({
-  modelId: 'meta.llama2',
+  modelId: 'meta.llama',
  body: JSON.stringify({
    prompt: 'prompt',
    max_gen_length: 5
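The streaming path these tests cover corresponds to `InvokeModelWithResponseStreamCommand`, whose response body is an async iterable of events carrying the same `chunk.bytes` the assertions check. A hedged consumption sketch, not part of this commit (region and prompt are placeholders):

```js
// Sketch: consuming a Llama response stream; handleLlama sees these same chunks.
const {
  BedrockRuntimeClient,
  InvokeModelWithResponseStreamCommand
} = require('@aws-sdk/client-bedrock-runtime')

async function streamCompletion() {
  const client = new BedrockRuntimeClient({ region: 'us-east-1' })
  const response = await client.send(
    new InvokeModelWithResponseStreamCommand({
      modelId: 'meta.llama3-8b-instruct-v1:0',
      contentType: 'application/json',
      accept: 'application/json',
      body: JSON.stringify({ prompt: 'who are you' })
    })
  )
  let text = ''
  for await (const event of response.body) {
    if (!event.chunk?.bytes) continue // skip non-data events
    const piece = JSON.parse(new TextDecoder().decode(event.chunk.bytes))
    text += piece.generation ?? ''
  }
  return text
}
```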
5 changes: 3 additions & 2 deletions test/versioned/aws-sdk-v3/bedrock-chat-completions.tap.js
@@ -48,7 +48,7 @@ const requests = {
    body: JSON.stringify({ prompt, temperature: 0.5, max_tokens: 100 }),
    modelId
  }),
- llama2: (prompt, modelId) => ({
+ llama: (prompt, modelId) => ({
    body: JSON.stringify({ prompt, max_gen_length: 100, temperature: 0.5 }),
    modelId
  })
@@ -98,7 +98,8 @@ tap.afterEach(async (t) => {
  { modelId: 'anthropic.claude-v2', resKey: 'claude' },
  { modelId: 'anthropic.claude-3-haiku-20240307-v1:0', resKey: 'claude3' },
  { modelId: 'cohere.command-text-v14', resKey: 'cohere' },
- { modelId: 'meta.llama2-13b-chat-v1', resKey: 'llama2' }
+ { modelId: 'meta.llama2-13b-chat-v1', resKey: 'llama' },
+ { modelId: 'meta.llama3-8b-instruct-v1:0', resKey: 'llama' }
].forEach(({ modelId, resKey }) => {
  tap.test(`${modelId}: should properly create completion segment`, (t) => {
    const { bedrock, client, responses, agent, expectedExternalPath } = t.context
