Skip to content

Commit

Permalink
add support for azure managed identity with identityTokenProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
ShervK committed Nov 27, 2024
1 parent 98c3e5d commit 4804e03
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 8 deletions.
7 changes: 7 additions & 0 deletions content/providers/01-ai-sdk-providers/02-azure.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ You can use the following optional settings to customize the OpenAI provider ins
API key that is being sent using the `api-key` header.
It defaults to the `AZURE_API_KEY` environment variable.

- **identityTokenProvider** _Promise<string>_

A function that returns a promise resolving to a token to be sent using the `Authorization` header.
This can be used with the `@azure/identity` package and `getBearerTokenProvider` function to authenticate with Azure Managed Identity.

If both `apiKey` and `identityTokenProvider` are provided, the `identityTokenProvider` takes precedence.

- **apiVersion** _string_

Sets a custom [api version](https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation).
Expand Down
178 changes: 178 additions & 0 deletions packages/azure/src/azure-openai-provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,67 @@ describe('chat', () => {
'https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-01-preview',
);
});

it('should use bearer token from identityTokenProvider', async () => {
prepareJsonResponse();

const provider = createAzure({
resourceName: 'test-resource',
identityTokenProvider: async () => 'test-token',
});

await provider('test-deployment').doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

const requestHeaders = await server.getRequestHeaders();
expect(requestHeaders).toStrictEqual({
'content-type': 'application/json',
authorization: 'Bearer test-token',
});
});

it('should throw error if identity token is invalid', async () => {
prepareJsonResponse();

const provider = createAzure({
resourceName: 'test-resource',
identityTokenProvider: async () => 1 as any,
});

await expect(
provider('test-deployment').doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
}),
).rejects.toThrow(
'Invalid Azure Managed Identity token format: token must be a non-empty string. Received: 1',
);
});

it('should throw error if identity token provider fails', async () => {
prepareJsonResponse();

const provider = createAzure({
resourceName: 'test-resource',
identityTokenProvider: async () => {
throw new Error('Network error');
},
});

await expect(
provider('test-deployment').doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
}),
).rejects.toThrow(
'Failed to fetch Azure Managed Identity token: Network error',
);
});
});
});

Expand Down Expand Up @@ -227,6 +288,67 @@ describe('completion', () => {
'custom-request-header': 'request-header-value',
});
});

it('should use bearer token from identityTokenProvider', async () => {
prepareJsonCompletionResponse({ content: 'Hello World!' });

const provider = createAzure({
resourceName: 'test-resource',
identityTokenProvider: async () => 'test-token',
});

await provider.completion('gpt-35-turbo-instruct').doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

const requestHeaders = await server.getRequestHeaders();
expect(requestHeaders).toStrictEqual({
'content-type': 'application/json',
authorization: 'Bearer test-token',
});
});

it('should throw error if identity token is invalid', async () => {
prepareJsonCompletionResponse({ content: 'Hello World!' });

const provider = createAzure({
resourceName: 'test-resource',
identityTokenProvider: async () => 1 as any,
});

await expect(
provider.completion('gpt-35-turbo-instruct').doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
}),
).rejects.toThrow(
'Invalid Azure Managed Identity token format: token must be a non-empty string. Received: 1',
);
});

it('should throw error if identity token provider fails', async () => {
prepareJsonCompletionResponse({ content: 'Hello World!' });

const provider = createAzure({
resourceName: 'test-resource',
identityTokenProvider: async () => {
throw new Error('Network error');
},
});

await expect(
provider.completion('gpt-35-turbo-instruct').doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
}),
).rejects.toThrow(
'Failed to fetch Azure Managed Identity token: Network error',
);
});
});
});

Expand Down Expand Up @@ -303,5 +425,61 @@ describe('embedding', () => {
'custom-request-header': 'request-header-value',
});
});

it('should use bearer token from identityTokenProvider', async () => {
prepareJsonResponse();

const provider = createAzure({
resourceName: 'test-resource',
identityTokenProvider: async () => 'test-token',
});

await provider.embedding('my-embedding').doEmbed({
values: testValues,
});

const requestHeaders = await server.getRequestHeaders();

expect(requestHeaders).toStrictEqual({
'content-type': 'application/json',
authorization: 'Bearer test-token',
});
});

it('should throw error if identity token is invalid', async () => {
prepareJsonResponse();

const provider = createAzure({
resourceName: 'test-resource',
identityTokenProvider: async () => 1 as any,
});

await expect(
provider.embedding('my-embedding').doEmbed({
values: testValues,
}),
).rejects.toThrow(
'Invalid Azure Managed Identity token format: token must be a non-empty string. Received: 1',
);
});

it('should throw error if identity token provider fails', async () => {
prepareJsonResponse();

const provider = createAzure({
resourceName: 'test-resource',
identityTokenProvider: async () => {
throw new Error('Network error');
},
});

await expect(
provider.embedding('my-embedding').doEmbed({
values: testValues,
}),
).rejects.toThrow(
'Failed to fetch Azure Managed Identity token: Network error',
);
});
});
});
55 changes: 47 additions & 8 deletions packages/azure/src/azure-openai-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ or to provide a custom fetch implementation for e.g. testing.
Custom api version to use. Defaults to `2024-10-01-preview`.
*/
apiVersion?: string;

/**
Function to fetch an authorization token using Azure Managed Identity using the `@azure/identity` package. If provided, the `apiKey` is ignored.
*/
identityTokenProvider?: () => Promise<string>;
}

/**
Expand All @@ -107,11 +112,15 @@ export function createAzure(
options: AzureOpenAIProviderSettings = {},
): AzureOpenAIProvider {
const getHeaders = () => ({
'api-key': loadApiKey({
apiKey: options.apiKey,
environmentVariableName: 'AZURE_API_KEY',
description: 'Azure OpenAI',
}),
...(options.identityTokenProvider
? {}
: {
'api-key': loadApiKey({
apiKey: options.apiKey,
environmentVariableName: 'AZURE_API_KEY',
description: 'Azure OpenAI',
}),
}),
...options.headers,
});

Expand All @@ -129,6 +138,36 @@ export function createAzure(
? `${options.baseURL}/${modelId}${path}?api-version=${apiVersion}`
: `https://${getResourceName()}.openai.azure.com/openai/deployments/${modelId}${path}?api-version=${apiVersion}`;

const wrappedFetch = async function (...args: Parameters<FetchFunction>) {
if (options.identityTokenProvider) {
const [input, init] = args;
let token: string;

try {
token = await options.identityTokenProvider();
} catch (error) {
throw new Error(
'Failed to fetch Azure Managed Identity token: ' +
(error as Error).message,
);
}
if (!token || typeof token !== 'string') {
throw new Error(
`Invalid Azure Managed Identity token format: token must be a non-empty string. Received: ${token}`,
);
}

return (options.fetch || fetch)(input, {
...init,
headers: {
...init?.headers,
Authorization: `Bearer ${token}`,
},
});
}
return (options.fetch || fetch)(...args);
} as FetchFunction;

const createChatModel = (
deploymentName: string,
settings: OpenAIChatSettings = {},
Expand All @@ -138,7 +177,7 @@ export function createAzure(
url,
headers: getHeaders,
compatibility: 'strict',
fetch: options.fetch,
fetch: wrappedFetch,
});

const createCompletionModel = (
Expand All @@ -150,7 +189,7 @@ export function createAzure(
url,
compatibility: 'strict',
headers: getHeaders,
fetch: options.fetch,
fetch: wrappedFetch,
});

const createEmbeddingModel = (
Expand All @@ -161,7 +200,7 @@ export function createAzure(
provider: 'azure-openai.embeddings',
headers: getHeaders,
url,
fetch: options.fetch,
fetch: wrappedFetch,
});

const provider = function (
Expand Down

0 comments on commit 4804e03

Please sign in to comment.