diff --git a/src/subscription/__tests__/flattenAsyncIterator-test.ts b/src/subscription/__tests__/flattenAsyncIterator-test.ts new file mode 100644 index 0000000000..ece7402157 --- /dev/null +++ b/src/subscription/__tests__/flattenAsyncIterator-test.ts @@ -0,0 +1,141 @@ +import { expect } from 'chai'; +import { describe, it } from 'mocha'; + +import { flattenAsyncIterator } from '../flattenAsyncIterator'; + +describe('flattenAsyncIterator', () => { + it('does not modify an already flat async generator', async () => { + async function* source() { + yield await Promise.resolve(1); + yield await Promise.resolve(2); + yield await Promise.resolve(3); + } + + const result = flattenAsyncIterator(source()); + + expect(await result.next()).to.deep.equal({ value: 1, done: false }); + expect(await result.next()).to.deep.equal({ value: 2, done: false }); + expect(await result.next()).to.deep.equal({ value: 3, done: false }); + expect(await result.next()).to.deep.equal({ + value: undefined, + done: true, + }); + }); + + it('does not modify an already flat async iterator', async () => { + const items = [1, 2, 3]; + + const iterator: any = { + [Symbol.asyncIterator]() { + return this; + }, + next() { + return Promise.resolve({ + done: items.length === 0, + value: items.shift(), + }); + }, + }; + + const result = flattenAsyncIterator(iterator); + + expect(await result.next()).to.deep.equal({ value: 1, done: false }); + expect(await result.next()).to.deep.equal({ value: 2, done: false }); + expect(await result.next()).to.deep.equal({ value: 3, done: false }); + expect(await result.next()).to.deep.equal({ + value: undefined, + done: true, + }); + }); + + it('flatten nested async generators', async () => { + async function* source() { + yield await Promise.resolve(1); + yield await Promise.resolve(2); + yield await Promise.resolve( + (async function* nested(): AsyncGenerator { + yield await Promise.resolve(2.1); + yield await Promise.resolve(2.2); + })(), + ); + yield await Promise.resolve(3); + } + + const doubles = flattenAsyncIterator(source()); + + const result = []; + for await (const x of doubles) { + result.push(x); + } + expect(result).to.deep.equal([1, 2, 2.1, 2.2, 3]); + }); + + it('allows returning early from a nested async generator', async () => { + async function* source() { + yield await Promise.resolve(1); + yield await Promise.resolve(2); + yield await Promise.resolve( + (async function* nested(): AsyncGenerator { + yield await Promise.resolve(2.1); + // istanbul ignore next (Shouldn't be reached) + yield await Promise.resolve(2.2); + })(), + ); + // istanbul ignore next (Shouldn't be reached) + yield await Promise.resolve(3); + } + + const doubles = flattenAsyncIterator(source()); + + expect(await doubles.next()).to.deep.equal({ value: 1, done: false }); + expect(await doubles.next()).to.deep.equal({ value: 2, done: false }); + expect(await doubles.next()).to.deep.equal({ value: 2.1, done: false }); + + // Early return + expect(await doubles.return()).to.deep.equal({ + value: undefined, + done: true, + }); + + // Subsequent next calls + expect(await doubles.next()).to.deep.equal({ + value: undefined, + done: true, + }); + expect(await doubles.next()).to.deep.equal({ + value: undefined, + done: true, + }); + }); + + it('allows throwing errors from a nested async generator', async () => { + async function* source() { + yield await Promise.resolve(1); + yield await Promise.resolve(2); + yield await Promise.resolve( + (async function* nested(): AsyncGenerator { + yield await Promise.resolve(2.1); + // istanbul ignore next (Shouldn't be reached) + yield await Promise.resolve(2.2); + })(), + ); + // istanbul ignore next (Shouldn't be reached) + yield await Promise.resolve(3); + } + + const doubles = flattenAsyncIterator(source()); + + expect(await doubles.next()).to.deep.equal({ value: 1, done: false }); + expect(await doubles.next()).to.deep.equal({ value: 2, done: false }); + expect(await doubles.next()).to.deep.equal({ value: 2.1, done: false }); + + // Throw error + let caughtError; + try { + await doubles.throw('ouch'); + } catch (e) { + caughtError = e; + } + expect(caughtError).to.equal('ouch'); + }); +}); diff --git a/src/subscription/__tests__/subscribe-test.ts b/src/subscription/__tests__/subscribe-test.ts index 85d6400a3d..067066e5a3 100644 --- a/src/subscription/__tests__/subscribe-test.ts +++ b/src/subscription/__tests__/subscribe-test.ts @@ -79,17 +79,22 @@ const emailSchema = new GraphQLSchema({ }), }); -function createSubscription(pubsub: SimplePubSub) { +function createSubscription( + pubsub: SimplePubSub, + variableValues?: { readonly [variable: string]: unknown }, +) { const document = parse(` - subscription ($priority: Int = 0) { + subscription ($priority: Int = 0, $shouldDefer: Boolean = false) { importantEmail(priority: $priority) { email { from subject } - inbox { - unread - total + ... @defer(if: $shouldDefer) { + inbox { + unread + total + } } } } @@ -119,7 +124,12 @@ function createSubscription(pubsub: SimplePubSub) { }), }; - return subscribe({ schema: emailSchema, document, rootValue: data }); + return subscribe({ + schema: emailSchema, + document, + rootValue: data, + variableValues, + }); } async function expectPromise(promise: Promise) { @@ -616,6 +626,136 @@ describe('Subscription Publish Phase', () => { }); }); + it('produces additional payloads for subscriptions with @defer', async () => { + const pubsub = new SimplePubSub(); + const subscription = await createSubscription(pubsub, { + shouldDefer: true, + }); + invariant(isAsyncIterable(subscription)); + // Wait for the next subscription payload. + const payload = subscription.next(); + + // A new email arrives! + expect( + pubsub.emit({ + from: 'yuzhi@graphql.org', + subject: 'Alright', + message: 'Tests are good', + unread: true, + }), + ).to.equal(true); + + // The previously waited on payload now has a value. + expect(await payload).to.deep.equal({ + done: false, + value: { + data: { + importantEmail: { + email: { + from: 'yuzhi@graphql.org', + subject: 'Alright', + }, + }, + }, + hasNext: true, + }, + }); + + // Wait for the next payload from @defer + expect(await subscription.next()).to.deep.equal({ + done: false, + value: { + data: { + inbox: { + unread: 1, + total: 2, + }, + }, + path: ['importantEmail'], + hasNext: false, + }, + }); + + // Another new email arrives, after all incrementally delivered payloads are received. + expect( + pubsub.emit({ + from: 'hyo@graphql.org', + subject: 'Tools', + message: 'I <3 making things', + unread: true, + }), + ).to.equal(true); + + // The next waited on payload will have a value. + expect(await subscription.next()).to.deep.equal({ + done: false, + value: { + data: { + importantEmail: { + email: { + from: 'hyo@graphql.org', + subject: 'Tools', + }, + }, + }, + hasNext: true, + }, + }); + + // Another new email arrives, before the incrementally delivered payloads from the last email was received. + expect( + pubsub.emit({ + from: 'adam@graphql.org', + subject: 'Important', + message: 'Read me please', + unread: true, + }), + ).to.equal(true); + + // Deferred payload from previous event is received. + expect(await subscription.next()).to.deep.equal({ + done: false, + value: { + data: { + inbox: { + unread: 2, + total: 3, + }, + }, + path: ['importantEmail'], + hasNext: false, + }, + }); + + // Next payload from last event + expect(await subscription.next()).to.deep.equal({ + done: false, + value: { + data: { + importantEmail: { + email: { + from: 'adam@graphql.org', + subject: 'Important', + }, + }, + }, + hasNext: true, + }, + }); + + // The client disconnects before the deferred payload is consumed. + expect(await subscription.return()).to.deep.equal({ + done: true, + value: undefined, + }); + + // Awaiting a subscription after closing it results in completed results. + expect(await subscription.next()).to.deep.equal({ + done: true, + value: undefined, + }); + }); + it('produces a payload when there are multiple events', async () => { const pubsub = new SimplePubSub(); const subscription = await createSubscription(pubsub); diff --git a/src/subscription/flattenAsyncIterator.ts b/src/subscription/flattenAsyncIterator.ts new file mode 100644 index 0000000000..1533482bb9 --- /dev/null +++ b/src/subscription/flattenAsyncIterator.ts @@ -0,0 +1,50 @@ +import { isAsyncIterable } from '../jsutils/isAsyncIterable'; + +type AsyncIterableOrGenerator = + | AsyncGenerator + | AsyncIterable; + +/** + * Given an AsyncIterable that could potentially yield other async iterators, + * flatten all yielded results into a single AsyncIterable + */ +export function flattenAsyncIterator( + iterable: AsyncIterableOrGenerator>, +): AsyncGenerator { + const iteratorMethod = iterable[Symbol.asyncIterator]; + const iterator: any = iteratorMethod.call(iterable); + let iteratorStack: Array> = [iterator]; + + async function next(): Promise> { + const currentIterator = iteratorStack[0]; + if (!currentIterator) { + return { value: undefined, done: true }; + } + const result = await currentIterator.next(); + if (result.done) { + iteratorStack.shift(); + return next(); + } else if (isAsyncIterable(result.value)) { + const childIterator = result.value[ + Symbol.asyncIterator + ]() as AsyncIterator; + iteratorStack.unshift(childIterator); + return next(); + } + return result; + } + return { + next, + return() { + iteratorStack = []; + return iterator.return(); + }, + throw(error?: unknown): Promise> { + iteratorStack = []; + return iterator.throw(error); + }, + [Symbol.asyncIterator]() { + return this; + }, + }; +} diff --git a/src/subscription/subscribe.ts b/src/subscription/subscribe.ts index daf9b8fdb2..c5b5c4b6b2 100644 --- a/src/subscription/subscribe.ts +++ b/src/subscription/subscribe.ts @@ -1,7 +1,6 @@ import { inspect } from '../jsutils/inspect'; import { isAsyncIterable } from '../jsutils/isAsyncIterable'; import { addPath, pathToArray } from '../jsutils/Path'; -import type { PromiseOrValue } from '../jsutils/PromiseOrValue'; import type { Maybe } from '../jsutils/Maybe'; import { GraphQLError } from '../error/GraphQLError'; @@ -9,7 +8,11 @@ import { locatedError } from '../error/locatedError'; import type { DocumentNode } from '../language/ast'; -import type { ExecutionResult, ExecutionContext } from '../execution/execute'; +import type { + ExecutionResult, + ExecutionContext, + AsyncExecutionResult, +} from '../execution/execute'; import { getArgumentValues } from '../execution/values'; import { assertValidExecutionArguments, @@ -26,6 +29,7 @@ import type { GraphQLFieldResolver } from '../type/definition'; import { getOperationRootType } from '../utilities/getOperationRootType'; import { mapAsyncIterator } from './mapAsyncIterator'; +import { flattenAsyncIterator } from './flattenAsyncIterator'; export interface SubscriptionArgs { schema: GraphQLSchema; @@ -61,7 +65,10 @@ export interface SubscriptionArgs { */ export async function subscribe( args: SubscriptionArgs, -): Promise | ExecutionResult> { +): Promise< + | AsyncGenerator + | ExecutionResult +> { const { schema, document, @@ -93,8 +100,8 @@ export async function subscribe( // the GraphQL specification. The `execute` function provides the // "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the // "ExecuteQuery" algorithm, for which `execute` is also used. - const mapSourceToResponse = (payload: unknown) => { - const executionResult = execute({ + const mapSourceToResponse = (payload: unknown) => + execute({ schema, document, rootValue: payload, @@ -103,17 +110,11 @@ export async function subscribe( operationName, fieldResolver, }); - /* istanbul ignore if - TODO: implement support for defer/stream in subscriptions */ - if (isAsyncIterable(executionResult)) { - throw new Error( - 'TODO: implement support for defer/stream in subscriptions', - ); - } - return executionResult as PromiseOrValue; - }; // Map every source value to a ExecutionResult value as described above. - return mapAsyncIterator(resultOrStream, mapSourceToResponse); + return flattenAsyncIterator( + mapAsyncIterator(resultOrStream, mapSourceToResponse), + ); } /**