Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(json-schema): correct handling of nested recursive schemas #992

Merged
merged 3 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions src/_vendor/zod-to-json-schema/Options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ export type Options<Target extends Targets = 'jsonSchema7'> = {
openaiStrictMode?: boolean;
};

export const defaultOptions: Options = {
const defaultOptions: Omit<Options, 'definitions' | 'basePath'> = {
name: undefined,
$refStrategy: 'root',
basePath: ['#'],
effectStrategy: 'input',
pipeStrategy: 'all',
dateStrategy: 'format:date-time',
Expand All @@ -51,7 +50,6 @@ export const defaultOptions: Options = {
definitionPath: 'definitions',
target: 'jsonSchema7',
strictUnions: false,
definitions: {},
errorMessages: false,
markdownDescription: false,
patternStrategy: 'escape',
Expand All @@ -63,13 +61,20 @@ export const defaultOptions: Options = {

export const getDefaultOptions = <Target extends Targets>(
options: Partial<Options<Target>> | string | undefined,
) =>
(typeof options === 'string' ?
{
...defaultOptions,
name: options,
}
: {
...defaultOptions,
...options,
}) as Options<Target>;
) => {
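// `definitions` and `basePath` are created fresh on every call here (rather than
// living in the shared `defaultOptions` object) because we may mutate `definitions`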
return (
typeof options === 'string' ?
{
...defaultOptions,
basePath: ['#'],
definitions: {},
name: options,
}
: {
...defaultOptions,
basePath: ['#'],
definitions: {},
...options,
}) as Options<Target>;
};
28 changes: 21 additions & 7 deletions src/_vendor/zod-to-json-schema/zodToJsonSchema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,28 @@ const zodToJsonSchema = <Target extends Targets = 'jsonSchema7'>(
}

const definitions: Record<string, any> = {};
const processedDefinitions = new Set();

for (const [name, zodSchema] of Object.entries(refs.definitions)) {
definitions[name] =
parseDef(
zodDef(zodSchema),
{ ...refs, currentPath: [...refs.basePath, refs.definitionPath, name] },
true,
) ?? {};
// Resolve every entry of `refs.definitions` into the local `definitions` map.
//
// The call to `parseDef()` here might itself add more entries to `refs.definitions`,
// so we keep re-scanning for keys we have not processed yet until a pass finds none.
//
// The number of passes is capped at a generous 500 so that a bug which keeps
// registering new definitions cannot make this loop run indefinitely. If the cap is
// ever reached the loop just stops; definitions that are still unprocessed at that
// point are not added to `definitions`.
for (let i = 0; i < 500; i++) {
// Collect the unprocessed entries before parsing any of them, since parsing
// may add further entries to `refs.definitions`; those are picked up next pass.
const newDefinitions = Object.entries(refs.definitions).filter(
([key]) => !processedDefinitions.has(key),
);
// No unprocessed keys remain, so every definition has been resolved.
if (newDefinitions.length === 0) break;

for (const [key, schema] of newDefinitions) {
// Parse with `currentPath` pointing at this definition's own location;
// fall back to an empty schema when `parseDef()` returns null/undefined.
definitions[key] =
parseDef(
zodDef(schema),
{ ...refs, currentPath: [...refs.basePath, refs.definitionPath, key] },
true,
) ?? {};
// Remember the key so later passes skip it.
processedDefinitions.add(key);
}
}

return definitions;
Expand Down
28 changes: 28 additions & 0 deletions tests/lib/__snapshots__/parser.test.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,34 @@ exports[`.parse() zod nested schema extraction 2`] = `
"
`;

exports[`.parse() zod recursive schema extraction 2`] = `
"{
"id": "chatcmpl-9vdbw9dekyUSEsSKVQDhTxA2RCxcK",
"object": "chat.completion",
"created": 1723523988,
"model": "gpt-4o-2024-08-06",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "{\\"linked_list\\":{\\"value\\":1,\\"next\\":{\\"value\\":2,\\"next\\":{\\"value\\":3,\\"next\\":{\\"value\\":4,\\"next\\":{\\"value\\":5,\\"next\\":null}}}}}}",
"refusal": null
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 40,
"completion_tokens": 38,
"total_tokens": 78
},
"system_fingerprint": "fp_2a322c9ffc"
}
"
`;

exports[`.parse() zod top-level recursive schemas 1`] = `
"{
"id": "chatcmpl-9uLhw79ArBF4KsQQOlsoE68m6vh6v",
Expand Down
182 changes: 175 additions & 7 deletions tests/lib/parser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -525,13 +525,6 @@ describe('.parse()', () => {
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": false,
"definitions": {
"contactPerson_properties_person1_properties_name": {
"type": "string",
},
"contactPerson_properties_person1_properties_phone_number": {
"nullable": true,
"type": "string",
},
"query": {
"additionalProperties": false,
"properties": {
Expand Down Expand Up @@ -616,6 +609,21 @@ describe('.parse()', () => {
},
],
},
"query_properties_fields_items_anyOf_0_properties_metadata_anyOf_0": {
"additionalProperties": false,
"properties": {
"foo": {
"$ref": "#/definitions/query_properties_fields_items_anyOf_0_properties_metadata_anyOf_0_properties_foo",
},
},
"required": [
"foo",
],
"type": "object",
},
"query_properties_fields_items_anyOf_0_properties_metadata_anyOf_0_properties_foo": {
"type": "string",
},
},
"properties": {
"fields": {
Expand Down Expand Up @@ -783,5 +791,165 @@ describe('.parse()', () => {
}
`);
});

// Covers a self-referential (linked-list) Zod schema nested inside an object:
// first the JSON schema generated by `zodResponseFormat()` is compared against an
// inline snapshot, then a snapshot-backed `parse()` request is made and the
// resulting assistant message (raw content plus parsed value) is compared too.
test('recursive schema extraction', async () => {
// Non-recursive part of a list node.
const baseLinkedListNodeSchema = z.object({
value: z.number(),
});
// The recursive TypeScript type is written by hand and attached to the schema
// below through the explicit `z.ZodType<LinkedListNode>` annotation.
type LinkedListNode = z.infer<typeof baseLinkedListNodeSchema> & {
next: LinkedListNode | null;
};
// `next` is wrapped in `z.lazy()` so the callback can refer to
// `linkedListNodeSchema` from inside its own initializer.
const linkedListNodeSchema: z.ZodType<LinkedListNode> = baseLinkedListNodeSchema.extend({
next: z.lazy(() => z.union([linkedListNodeSchema, z.null()])),
});

// Define the main schema
const mainSchema = z.object({
linked_list: linkedListNodeSchema,
});

// Expected output: the recursive node is registered under `definitions` as
// `query_properties_linked_list`, and each `next` branch points back to it via `$ref`.
expect(zodResponseFormat(mainSchema, 'query').json_schema.schema).toMatchInlineSnapshot(`
{
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": false,
"definitions": {
"query": {
"additionalProperties": false,
"properties": {
"linked_list": {
"additionalProperties": false,
"properties": {
"next": {
"anyOf": [
{
"$ref": "#/definitions/query_properties_linked_list",
},
{
"type": "null",
},
],
},
"value": {
"type": "number",
},
},
"required": [
"value",
"next",
],
"type": "object",
},
},
"required": [
"linked_list",
],
"type": "object",
},
"query_properties_linked_list": {
"additionalProperties": false,
"properties": {
"next": {
"$ref": "#/definitions/query_properties_linked_list_properties_next",
},
"value": {
"$ref": "#/definitions/query_properties_linked_list_properties_value",
},
},
"required": [
"value",
"next",
],
"type": "object",
},
"query_properties_linked_list_properties_next": {
"anyOf": [
{
"$ref": "#/definitions/query_properties_linked_list",
},
{
"type": "null",
},
],
},
"query_properties_linked_list_properties_value": {
"type": "number",
},
},
"properties": {
"linked_list": {
"additionalProperties": false,
"properties": {
"next": {
"anyOf": [
{
"$ref": "#/definitions/query_properties_linked_list",
},
{
"type": "null",
},
],
},
"value": {
"type": "number",
},
},
"required": [
"value",
"next",
],
"type": "object",
},
},
"required": [
"linked_list",
],
"type": "object",
}
`);

// Send the same schema as `response_format` through the snapshot request helper.
// NOTE(review): the trailing `2` presumably selects/numbers the stored snapshot
// entry for this test — confirm against `makeSnapshotRequest`.
const completion = await makeSnapshotRequest(
(openai) =>
openai.beta.chat.completions.parse({
model: 'gpt-4o-2024-08-06',
messages: [
{
role: 'system',
content:
"You are a helpful assistant. Generate a data model according to the user's instructions.",
},
{ role: 'user', content: 'create a linklist from 1 to 5' },
],
response_format: zodResponseFormat(mainSchema, 'query'),
}),
2,
);

// The message should carry both the raw JSON `content` and the nested `parsed`
// value, with the list terminated by `next: null` on the last node.
expect(completion.choices[0]?.message).toMatchInlineSnapshot(`
{
"content": "{"linked_list":{"value":1,"next":{"value":2,"next":{"value":3,"next":{"value":4,"next":{"value":5,"next":null}}}}}}",
"parsed": {
"linked_list": {
"next": {
"next": {
"next": {
"next": {
"next": null,
"value": 5,
},
"value": 4,
},
"value": 3,
},
"value": 2,
},
"value": 1,
},
},
"refusal": null,
"role": "assistant",
"tool_calls": [],
}
`);
});
});
});