Skip to content

Commit

Permalink
refactor(streaming): change Stream constructor signature (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored Oct 12, 2023
1 parent b176703 commit 71984ed
Show file tree
Hide file tree
Showing 14 changed files with 192 additions and 55 deletions.
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@
"digest-fetch": "^1.3.0",
"form-data-encoder": "1.7.2",
"formdata-node": "^4.3.2",
"node-fetch": "^2.6.7"
"node-fetch": "^2.6.7",
"web-streams-polyfill": "^3.2.1"
},
"devDependencies": {
"@types/jest": "^29.4.0",
Expand Down
4 changes: 3 additions & 1 deletion src/_shims/auto/types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,6 @@ export declare class FsReadStream extends Readable {

// @ts-ignore
type _ReadableStream<R = any> = unknown extends ReadableStream<R> ? never : ReadableStream<R>;
export { type _ReadableStream as ReadableStream };
// @ts-ignore
declare const _ReadableStream: unknown extends typeof ReadableStream ? never : typeof ReadableStream;
export { _ReadableStream as ReadableStream };
2 changes: 2 additions & 0 deletions src/_shims/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ export type Readable = SelectType<manual.Readable, auto.Readable>;
export type FsReadStream = SelectType<manual.FsReadStream, auto.FsReadStream>;
// @ts-ignore
export type ReadableStream = SelectType<manual.ReadableStream, auto.ReadableStream>;
// @ts-ignore
export const ReadableStream: SelectType<typeof manual.ReadableStream, typeof auto.ReadableStream>;

export function getMultipartRequestOptions<T extends {} = Record<string, unknown>>(
form: FormData,
Expand Down
2 changes: 2 additions & 0 deletions src/_shims/node-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { Readable } from 'node:stream';
import { type RequestOptions } from '../core';
import { MultipartBody } from './MultipartBody';
import { type Shims } from './registry';
import { ReadableStream } from 'web-streams-polyfill';

type FileFromPathOptions = Omit<FilePropertyBag, 'lastModified'>;

Expand Down Expand Up @@ -71,6 +72,7 @@ export function getRuntime(): Shims {
FormData: fd.FormData,
Blob: fd.Blob,
File: fd.File,
ReadableStream,
getMultipartRequestOptions,
getDefaultAgent: (url: string): Agent => (url.startsWith('https') ? defaultHttpsAgent : defaultHttpAgent),
fileFromPath,
Expand Down
2 changes: 1 addition & 1 deletion src/_shims/node-types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import * as fd from 'formdata-node';
export { type Agent } from 'node:http';
export { type Readable } from 'node:stream';
export { type ReadStream as FsReadStream } from 'node:fs';
export { type ReadableStream } from 'web-streams-polyfill';
export { ReadableStream } from 'web-streams-polyfill';

export const fetch: typeof nf.default;

Expand Down
3 changes: 3 additions & 0 deletions src/_shims/registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export interface Shims {
FormData: any;
Blob: any;
File: any;
ReadableStream: any;
getMultipartRequestOptions: <T extends {} = Record<string, unknown>>(
form: Shims['FormData'],
opts: RequestOptions<T>,
Expand All @@ -32,6 +33,7 @@ export let Headers: Shims['Headers'] | undefined = undefined;
export let FormData: Shims['FormData'] | undefined = undefined;
export let Blob: Shims['Blob'] | undefined = undefined;
export let File: Shims['File'] | undefined = undefined;
export let ReadableStream: Shims['ReadableStream'] | undefined = undefined;
export let getMultipartRequestOptions: Shims['getMultipartRequestOptions'] | undefined = undefined;
export let getDefaultAgent: Shims['getDefaultAgent'] | undefined = undefined;
export let fileFromPath: Shims['fileFromPath'] | undefined = undefined;
Expand All @@ -55,6 +57,7 @@ export function setShims(shims: Shims, options: { auto: boolean } = { auto: fals
FormData = shims.FormData;
Blob = shims.Blob;
File = shims.File;
ReadableStream = shims.ReadableStream;
getMultipartRequestOptions = shims.getMultipartRequestOptions;
getDefaultAgent = shims.getDefaultAgent;
fileFromPath = shims.fileFromPath;
Expand Down
12 changes: 12 additions & 0 deletions src/_shims/web-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ export function getRuntime({ manuallyImported }: { manuallyImported?: boolean }
}
}
),
ReadableStream:
// @ts-ignore
typeof ReadableStream !== 'undefined' ? ReadableStream : (
class ReadableStream {
// @ts-ignore
constructor() {
throw new Error(
`streaming isn't supported in this environment yet as 'ReadableStream' is undefined. ${recommendation}`,
);
}
}
),
getMultipartRequestOptions: async <T extends {} = Record<string, unknown>>(
// @ts-ignore
form: FormData,
Expand Down
3 changes: 2 additions & 1 deletion src/_shims/web-types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,5 @@ export declare class FsReadStream extends Readable {
}

type _ReadableStream<R = any> = ReadableStream<R>;
export { type _ReadableStream as ReadableStream };
declare const _ReadableStream: typeof ReadableStream;
export { _ReadableStream as ReadableStream };
2 changes: 1 addition & 1 deletion src/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async function defaultParseResponse<T>(props: APIResponseProps): Promise<T> {
if (props.options.stream) {
// Note: there is an invariant here that isn't represented in the type system
// that if you set `stream: true` the response type must also be `Stream<T>`
return new Stream(response, props.controller) as any;
return Stream.fromSSEResponse(response, props.controller) as any;
}

const contentType = response.headers.get('content-type');
Expand Down
9 changes: 5 additions & 4 deletions src/error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export class APIError extends OpenAIError {
message: string | undefined,
headers: Headers | undefined,
) {
super(`${status} ${APIError.makeMessage(error, message)}`);
super(`${APIError.makeMessage(status, error, message)}`);
this.status = status;
this.headers = headers;

Expand All @@ -30,13 +30,14 @@ export class APIError extends OpenAIError {
this.type = data?.['type'];
}

private static makeMessage(error: any, message: string | undefined) {
private static makeMessage(status: number | undefined, error: any, message: string | undefined) {
return (
error?.message ?
(status || '') +
(error?.message ?
typeof error.message === 'string' ? error.message
: JSON.stringify(error.message)
: error ? JSON.stringify(error)
: message || 'status code (no body)'
: message || 'status code (no body)')
);
}

Expand Down
2 changes: 1 addition & 1 deletion src/shims/node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ declare module '../_shims/manual-types' {
// @ts-ignore
export type FsReadStream = types.FsReadStream;
// @ts-ignore
export type ReadableStream = types.ReadableStream;
export import ReadableStream = types.ReadableStream;
}
}
2 changes: 1 addition & 1 deletion src/shims/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ declare module '../_shims/manual-types' {
// @ts-ignore
export type FsReadStream = types.FsReadStream;
// @ts-ignore
export type ReadableStream = types.ReadableStream;
export import ReadableStream = types.ReadableStream;
}
}
196 changes: 152 additions & 44 deletions src/streaming.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type Response } from './_shims/index';
import { ReadableStream, type Response } from './_shims/index';
import { OpenAIError } from './error';

type Bytes = string | ArrayBuffer | Uint8Array | Buffer | null | undefined;
Expand All @@ -12,67 +12,175 @@ type ServerSentEvent = {
export class Stream<Item> implements AsyncIterable<Item> {
controller: AbortController;

private response: Response;
private decoder: SSEDecoder;

constructor(response: Response, controller: AbortController) {
this.response = response;
constructor(private iterator: () => AsyncIterator<Item>, controller: AbortController) {
this.controller = controller;
this.decoder = new SSEDecoder();
}

private async *iterMessages(): AsyncGenerator<ServerSentEvent, void, unknown> {
if (!this.response.body) {
this.controller.abort();
throw new OpenAIError(`Attempted to iterate over a response with no body`);
}
static fromSSEResponse<Item>(response: Response, controller: AbortController) {
let consumed = false;
const decoder = new SSEDecoder();

async function* iterMessages(): AsyncGenerator<ServerSentEvent, void, unknown> {
if (!response.body) {
controller.abort();
throw new OpenAIError(`Attempted to iterate over a response with no body`);
}

const lineDecoder = new LineDecoder();

const lineDecoder = new LineDecoder();
const iter = readableStreamAsyncIterable<Bytes>(response.body);
for await (const chunk of iter) {
for (const line of lineDecoder.decode(chunk)) {
const sse = decoder.decode(line);
if (sse) yield sse;
}
}

const iter = readableStreamAsyncIterable<Bytes>(this.response.body);
for await (const chunk of iter) {
for (const line of lineDecoder.decode(chunk)) {
const sse = this.decoder.decode(line);
for (const line of lineDecoder.flush()) {
const sse = decoder.decode(line);
if (sse) yield sse;
}
}

for (const line of lineDecoder.flush()) {
const sse = this.decoder.decode(line);
if (sse) yield sse;
async function* iterator(): AsyncIterator<Item, any, undefined> {
if (consumed) {
throw new Error('Cannot iterate over a consumed stream, use `.tee()` to split the stream.');
}
consumed = true;
let done = false;
try {
for await (const sse of iterMessages()) {
if (done) continue;

if (sse.data.startsWith('[DONE]')) {
done = true;
continue;
}

if (sse.event === null) {
try {
yield JSON.parse(sse.data);
} catch (e) {
console.error(`Could not parse message into JSON:`, sse.data);
console.error(`From chunk:`, sse.raw);
throw e;
}
}
}
done = true;
} catch (e) {
// If the user calls `stream.controller.abort()`, we should exit without throwing.
if (e instanceof Error && e.name === 'AbortError') return;
throw e;
} finally {
// If the user `break`s, abort the ongoing request.
if (!done) controller.abort();
}
}

return new Stream(iterator, controller);
}

async *[Symbol.asyncIterator](): AsyncIterator<Item, any, undefined> {
let done = false;
try {
for await (const sse of this.iterMessages()) {
if (done) continue;
// Generates a Stream from a newline-separated ReadableStream where each item
// is a JSON Value.
static fromReadableStream<Item>(readableStream: ReadableStream, controller: AbortController) {
let consumed = false;

if (sse.data.startsWith('[DONE]')) {
done = true;
continue;
async function* iterLines(): AsyncGenerator<string, void, unknown> {
const lineDecoder = new LineDecoder();

const iter = readableStreamAsyncIterable<Bytes>(readableStream);
for await (const chunk of iter) {
for (const line of lineDecoder.decode(chunk)) {
yield line;
}
}

if (sse.event === null) {
try {
yield JSON.parse(sse.data);
} catch (e) {
console.error(`Could not parse message into JSON:`, sse.data);
console.error(`From chunk:`, sse.raw);
throw e;
}
for (const line of lineDecoder.flush()) {
yield line;
}
}

async function* iterator(): AsyncIterator<Item, any, undefined> {
if (consumed) {
throw new Error('Cannot iterate over a consumed stream, use `.tee()` to split the stream.');
}
consumed = true;
let done = false;
try {
for await (const line of iterLines()) {
if (done) continue;
if (line) yield JSON.parse(line);
}
done = true;
} catch (e) {
// If the user calls `stream.controller.abort()`, we should exit without throwing.
if (e instanceof Error && e.name === 'AbortError') return;
throw e;
} finally {
// If the user `break`s, abort the ongoing request.
if (!done) controller.abort();
}
done = true;
} catch (e) {
// If the user calls `stream.controller.abort()`, we should exit without throwing.
if (e instanceof Error && e.name === 'AbortError') return;
throw e;
} finally {
// If the user `break`s, abort the ongoing request.
if (!done) this.controller.abort();
}

return new Stream(iterator, controller);
}

[Symbol.asyncIterator](): AsyncIterator<Item> {
return this.iterator();
}

tee(): [Stream<Item>, Stream<Item>] {
const left: Array<Promise<IteratorResult<Item>>> = [];
const right: Array<Promise<IteratorResult<Item>>> = [];
const iterator = this.iterator();

const teeIterator = (queue: Array<Promise<IteratorResult<Item>>>): AsyncIterator<Item> => {
return {
next: () => {
if (queue.length === 0) {
const result = iterator.next();
left.push(result);
right.push(result);
}
return queue.shift()!;
},
};
};

return [
new Stream(() => teeIterator(left), this.controller),
new Stream(() => teeIterator(right), this.controller),
];
}

// Converts this stream to a newline-separated ReadableStream of JSON Stringified values in the stream
// which can be turned back into a Stream with Stream.fromReadableStream.
toReadableStream(): ReadableStream {
const self = this;
let iter: AsyncIterator<Item>;
const encoder = new TextEncoder();

return new ReadableStream({
async start() {
iter = self[Symbol.asyncIterator]();
},
async pull(ctrl) {
try {
const { value, done } = await iter.next();
if (done) return ctrl.close();

const bytes = encoder.encode(JSON.stringify(value) + '\n');

ctrl.enqueue(bytes);
} catch (err) {
ctrl.error(err);
}
},
async cancel() {
await iter.return?.();
},
});
}
}

Expand Down
5 changes: 5 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4067,6 +4067,11 @@ [email protected]:
resolved "https://registry.yarnpkg.com/web-streams-polyfill/-/web-streams-polyfill-4.0.0-beta.1.tgz#3b19b9817374b7cee06d374ba7eeb3aeb80e8c95"
integrity sha512-3ux37gEX670UUphBF9AMCq8XM6iQ8Ac6A+DSRRjDoRBm1ufCkaCDdNVbaqq60PsEkdNlLKrGtv/YBP4EJXqNtQ==

web-streams-polyfill@^3.2.1:
version "3.2.1"
resolved "https://registry.yarnpkg.com/web-streams-polyfill/-/web-streams-polyfill-3.2.1.tgz#71c2718c52b45fd49dbeee88634b3a60ceab42a6"
integrity sha512-e0MO3wdXWKrLbL0DgGnUV7WHVuw9OUvL4hjgnPkIeEvESk74gAITi5G606JtZPp39cd8HA9VQzCIvA49LpPN5Q==

webidl-conversions@^3.0.0:
version "3.0.1"
resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"
Expand Down

0 comments on commit 71984ed

Please sign in to comment.