diff --git a/docs/docs/api/Dispatcher.md b/docs/docs/api/Dispatcher.md index 933f4a730f8..c44abf5231b 100644 --- a/docs/docs/api/Dispatcher.md +++ b/docs/docs/api/Dispatcher.md @@ -985,6 +985,57 @@ client.dispatch( ); ``` +##### `dns` + +The `dns` interceptor enables you to cache DNS lookups for a given duration, per origin. + +>It is well suited for scenarios where you want to cache DNS lookups to avoid the overhead of resolving the same domain multiple times + +**Options** +- `maxTTL` - The maximum time-to-live (in milliseconds) of the DNS cache. It should be a positive integer. Default: `10000`. + - Set `0` to disable TTL. +- `maxItems` - The maximum number of items to cache. It should be a positive integer. Default: `Infinity`. +- `dualStack` - Whether to resolve both IPv4 and IPv6 addresses. Default: `true`. + - It will also attempt a happy-eyeballs-like approach to connect to the available addresses in case of a connection failure. +- `affinity` - Whether to use IPv4 or IPv6 addresses. Default: `4`. + - It can be either `'4` or `6`. + - It will only take effect if `dualStack` is `false`. +- `lookup: (hostname: string, options: LookupOptions, callback: (err: NodeJS.ErrnoException | null, addresses: DNSInterceptorRecord[]) => void) => void` - Custom lookup function. Default: `dns.lookup`. + - For more info see [dns.lookup](https://nodejs.org/api/dns.html#dns_dns_lookup_hostname_options_callback). +- `pick: (origin: URL, records: DNSInterceptorRecords, affinity: 4 | 6) => DNSInterceptorRecord` - Custom pick function. Default: `RoundRobin`. + - The function should return a single record from the records array. + - By default a simplified version of Round Robin is used. + - The `records` property can be mutated to store the state of the balancing algorithm. + +> The `Dispatcher#options` also gets extended with the options `dns.affinity`, `dns.dualStack`, `dns.lookup` and `dns.pick` which can be used to configure the interceptor at a request-per-request basis. + + +**DNSInterceptorRecord** +It represents a DNS record. +- `family` - (`number`) The IP family of the address. It can be either `4` or `6`. +- `address` - (`string`) The IP address. + +**DNSInterceptorOriginRecords** +It represents a map of DNS IP addresses records for a single origin. +- `4.ips` - (`DNSInterceptorRecord[] | null`) The IPv4 addresses. +- `6.ips` - (`DNSInterceptorRecord[] | null`) The IPv6 addresses. + +**Example - Basic DNS Interceptor** + +```js +const { Client, interceptors } = require("undici"); +const { dns } = interceptors; + +const client = new Agent().compose([ + dns({ ...opts }) +]) + +const response = await client.request({ + origin: `http://localhost:3030`, + ...requestOpts +}) +``` + ##### `Response Error Interceptor` **Introduction** diff --git a/index.js b/index.js index 444706560ae..090f6cfc892 100644 --- a/index.js +++ b/index.js @@ -39,7 +39,8 @@ module.exports.RedirectHandler = RedirectHandler module.exports.interceptors = { redirect: require('./lib/interceptor/redirect'), retry: require('./lib/interceptor/retry'), - dump: require('./lib/interceptor/dump') + dump: require('./lib/interceptor/dump'), + dns: require('./lib/interceptor/dns') } module.exports.buildConnector = buildConnector diff --git a/lib/interceptor/dns.js b/lib/interceptor/dns.js new file mode 100644 index 00000000000..e0e1d7fa486 --- /dev/null +++ b/lib/interceptor/dns.js @@ -0,0 +1,346 @@ +'use strict' +const { isIP } = require('node:net') +const { lookup } = require('node:dns') +const DecoratorHandler = require('../handler/decorator-handler') +const { InvalidArgumentError, InformationalError } = require('../core/errors') +const maxInt = Math.pow(2, 31) - 1 + +class DNSInstance { + #maxTTL = 0 + #maxItems = 0 + #records = new Map() + dualStack = true + affinity = null + lookup = null + pick = null + lastIpFamily = null + + constructor (opts) { + this.#maxTTL = opts.maxTTL + this.#maxItems = opts.maxItems + this.dualStack = opts.dualStack + this.affinity = opts.affinity + this.lookup = opts.lookup ?? this.#defaultLookup + this.pick = opts.pick ?? this.#defaultPick + } + + get full () { + return this.#records.size === this.#maxItems + } + + runLookup (origin, opts, cb) { + const ips = this.#records.get(origin.hostname) + + // If full, we just return the origin + if (ips == null && this.full) { + cb(null, origin.origin) + return + } + + const newOpts = { + affinity: this.affinity, + dualStack: this.dualStack, + lookup: this.lookup, + pick: this.pick, + ...opts.dns, + maxTTL: this.#maxTTL, + maxItems: this.#maxItems + } + + // If no IPs we lookup + if (ips == null) { + this.lookup(origin, newOpts, (err, addresses) => { + if (err || addresses == null || addresses.length === 0) { + cb(err ?? new InformationalError('No DNS entries found')) + return + } + + this.setRecords(origin, addresses) + const records = this.#records.get(origin.hostname) + + const ip = this.pick( + origin, + records, + // Only set affinity if dual stack is disabled + // otherwise let it go through normal flow + !newOpts.dualStack && newOpts.affinity + ) + + cb( + null, + `${origin.protocol}//${ + ip.family === 6 ? `[${ip.address}]` : ip.address + }${origin.port === '' ? '' : `:${origin.port}`}` + ) + }) + } else { + // If there's IPs we pick + const ip = this.pick( + origin, + ips, + // Only set affinity if dual stack is disabled + // otherwise let it go through normal flow + !newOpts.dualStack && newOpts.affinity + ) + + // If no IPs we lookup - deleting old records + if (ip == null) { + this.#records.delete(origin.hostname) + this.runLookup(origin, opts, cb) + return + } + + cb( + null, + `${origin.protocol}//${ + ip.family === 6 ? `[${ip.address}]` : ip.address + }${origin.port === '' ? '' : `:${origin.port}`}` + ) + } + } + + #defaultLookup (origin, opts, cb) { + lookup( + origin.hostname, + { all: true, family: this.dualStack === false ? this.affinity : 0 }, + (err, addresses) => { + if (err) { + return cb(err) + } + + const results = new Map() + + for (const addr of addresses) { + const record = { + address: addr.address, + ttl: opts.maxTTL, + family: addr.family + } + + // On linux we found duplicates, we attempt to remove them with + // the latest record + results.set(`${record.address}:${record.family}`, record) + } + + cb(null, results.values()) + } + ) + } + + #defaultPick (origin, hostnameRecords, affinity) { + let ip = null + const { records, offset = 0 } = hostnameRecords + let newOffset = 0 + + if (offset === maxInt) { + newOffset = 0 + } else { + newOffset = offset + 1 + } + + // We balance between the two IP families + // If dual-stack disabled, we automatically pick the affinity + const newIpFamily = (newOffset & 1) === 1 ? 4 : 6 + const family = + this.dualStack === false + ? records[this.affinity] // If dual-stack is disabled, we pick the default affiniy + : records[affinity] ?? records[newIpFamily] + + // If no IPs and we have tried both families or dual stack is disabled, we return null + if ( + (family == null || family.ips.length === 0) && + // eslint-disable-next-line eqeqeq + (this.dualStack === false || this.lastIpFamily != newIpFamily) + ) { + return ip + } + + family.offset = family.offset ?? 0 + hostnameRecords.offset = newOffset + + if (family.offset === maxInt) { + family.offset = 0 + } else { + family.offset++ + } + + const position = family.offset % family.ips.length + ip = family.ips[position] ?? null + + if (ip == null) { + return ip + } + + const timestamp = Date.now() + // Record TTL is already in ms + if (ip.timestamp != null && timestamp - ip.timestamp > ip.ttl) { + // We delete expired records + // It is possible that they have different TTL, so we manage them individually + family.ips.splice(position, 1) + return this.pick(origin, hostnameRecords, affinity) + } + + ip.timestamp = timestamp + + this.lastIpFamily = newIpFamily + return ip + } + + setRecords (origin, addresses) { + const records = { records: { 4: null, 6: null } } + for (const record of addresses) { + const familyRecords = records.records[record.family] ?? { ips: [] } + + familyRecords.ips.push(record) + records.records[record.family] = familyRecords + } + + this.#records.set(origin.hostname, records) + } + + getHandler (meta, opts) { + return new DNSDispatchHandler(this, meta, opts) + } +} + +class DNSDispatchHandler extends DecoratorHandler { + #state = null + #opts = null + #dispatch = null + #handler = null + #origin = null + + constructor (state, { origin, handler, dispatch }, opts) { + super(handler) + this.#origin = origin + this.#handler = handler + this.#opts = { ...opts } + this.#state = state + this.#dispatch = dispatch + } + + onError (err) { + switch (err.code) { + case 'ETIMEDOUT': + case 'ECONNREFUSED': { + if (this.#state.dualStack) { + // We delete the record and retry + this.#state.runLookup(this.#origin, this.#opts, (err, newOrigin) => { + if (err) { + return this.#handler.onError(err) + } + + const dispatchOpts = { + ...this.#opts, + origin: newOrigin + } + + this.#dispatch(dispatchOpts, this) + }) + + // if dual-stack disabled, we error out + return + } + + this.#handler.onError(err) + return + } + case 'ENOTFOUND': + this.#state.deleteRecord(this.#origin) + // eslint-disable-next-line no-fallthrough + default: + this.#handler.onError(err) + break + } + } +} + +module.exports = interceptorOpts => { + if ( + interceptorOpts?.maxTTL != null && + (typeof interceptorOpts?.maxTTL !== 'number' || interceptorOpts?.maxTTL < 0) + ) { + throw new InvalidArgumentError('Invalid maxTTL. Must be a positive number') + } + + if ( + interceptorOpts?.maxItems != null && + (typeof interceptorOpts?.maxItems !== 'number' || + interceptorOpts?.maxItems < 1) + ) { + throw new InvalidArgumentError( + 'Invalid maxItems. Must be a positive number and greater than zero' + ) + } + + if ( + interceptorOpts?.affinity != null && + interceptorOpts?.affinity !== 4 && + interceptorOpts?.affinity !== 6 + ) { + throw new InvalidArgumentError('Invalid affinity. Must be either 4 or 6') + } + + if ( + interceptorOpts?.dualStack != null && + typeof interceptorOpts?.dualStack !== 'boolean' + ) { + throw new InvalidArgumentError('Invalid dualStack. Must be a boolean') + } + + if ( + interceptorOpts?.lookup != null && + typeof interceptorOpts?.lookup !== 'function' + ) { + throw new InvalidArgumentError('Invalid lookup. Must be a function') + } + + if ( + interceptorOpts?.pick != null && + typeof interceptorOpts?.pick !== 'function' + ) { + throw new InvalidArgumentError('Invalid pick. Must be a function') + } + + const opts = { + maxTTL: interceptorOpts?.maxTTL ?? 10e3, // Expressed in ms + lookup: interceptorOpts?.lookup ?? null, + pick: interceptorOpts?.pick ?? null, + dualStack: interceptorOpts?.dualStack ?? true, + affinity: interceptorOpts?.affinity ?? 4, + maxItems: interceptorOpts?.maxItems ?? Infinity + } + + const instance = new DNSInstance(opts) + + return dispatch => { + return function dnsInterceptor (origDispatchOpts, handler) { + const origin = + origDispatchOpts.origin.constructor === URL + ? origDispatchOpts.origin + : new URL(origDispatchOpts.origin) + + if (isIP(origin.hostname) !== 0) { + return dispatch(origDispatchOpts, handler) + } + + instance.runLookup(origin, origDispatchOpts, (err, newOrigin) => { + if (err) { + return handler.onError(err) + } + + const dispatchOpts = { + ...origDispatchOpts, + origin: newOrigin + } + + dispatch( + dispatchOpts, + instance.getHandler({ origin, dispatch, handler }, origDispatchOpts) + ) + }) + + return true + } + } +} diff --git a/test/interceptors/dns.js b/test/interceptors/dns.js new file mode 100644 index 00000000000..e58a1d597ba --- /dev/null +++ b/test/interceptors/dns.js @@ -0,0 +1,803 @@ +'use strict' + +const { test, after } = require('node:test') +const { isIP } = require('node:net') +const { lookup } = require('node:dns') +const { createServer } = require('node:http') +const { once } = require('node:events') +const { setTimeout: sleep } = require('node:timers/promises') + +const { tspl } = require('@matteo.collina/tspl') + +const { interceptors, Agent } = require('../..') +const { dns } = interceptors + +const isWindows = process.platform === 'win32' + +test('Should validate options', t => { + t = tspl(t, { plan: 10 }) + + t.throws(() => dns({ dualStack: 'true' }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ dualStack: 0 }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ affinity: '4' }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ affinity: 7 }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ maxTTL: -1 }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ maxTTL: '0' }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ maxItems: '1' }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ maxItems: -1 }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ lookup: {} }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ pick: [] }), { code: 'UND_ERR_INVALID_ARG' }) +}) + +test('Should automatically resolve IPs (dual stack)', async t => { + t = tspl(t, { plan: 6 }) + + let counter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns() + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should recover on network errors (dual stack - 4)', async t => { + t = tspl(t, { plan: 8 }) + + let counter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0, '::1') + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 3: + // [::1] -> ::1 + t.equal(isIP(url.hostname), 4) + break + + case 4: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns() + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should recover on network errors (dual stack - 6)', async t => { + t = tspl(t, { plan: 7 }) + + let counter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0, '127.0.0.1') + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 3: + // [::1] -> ::1 + t.equal(isIP(url.hostname), 4) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns() + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should throw when on dual-stack disabled (4)', async t => { + t = tspl(t, { plan: 2 }) + + let counter = 0 + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ dualStack: false, affinity: 4 }) + ]) + + const promise = client.request({ + ...requestOptions, + origin: 'http://localhost:1234' + }) + + await t.rejects(promise, 'ECONNREFUSED') + + await t.complete +}) + +test('Should throw when on dual-stack disabled (6)', async t => { + t = tspl(t, { plan: 2 }) + + let counter = 0 + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ dualStack: false, affinity: 6 }) + ]) + + // Note: In windows the IPV6 does not results in ECONNREFUSED + // but rather in TIMEOUT + if (isWindows) { + const promise = client.request({ + ...requestOptions, + origin: 'http://localhost', + headersTimeout: 500 + }) + + await t.rejects(promise, 'UND_ERR_HEADERS_TIMEOUT') + } else { + const promise = client.request({ + ...requestOptions, + origin: 'http://localhost' + }) + + await t.rejects(promise, 'ECONNREFUSED') + } + + await t.complete +}) + +test('Should automatically resolve IPs (dual stack disabled - 4)', async t => { + t = tspl(t, { plan: 6 }) + + let counter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname), 4) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ dualStack: false }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should automatically resolve IPs (dual stack disabled - 6)', async t => { + t = tspl(t, { plan: 6 }) + + let counter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ dualStack: false, affinity: 6 }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should we handle TTL (4)', async t => { + t = tspl(t, { plan: 7 }) + + let counter = 0 + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0, '127.0.0.1') + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + t.equal(isIP(url.hostname), 4) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ + dualStack: false, + affinity: 4, + maxTTL: 100, + lookup: (origin, opts, cb) => { + ++lookupCounter + lookup( + origin.hostname, + { all: true, family: opts.affinity }, + (err, addresses) => { + if (err) { + return cb(err) + } + + const results = new Map() + + for (const addr of addresses) { + const record = { + address: addr.address, + ttl: opts.maxTTL, + family: addr.family + } + + results.set(`${record.address}:${record.family}`, record) + } + + cb(null, results.values()) + } + ) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + await sleep(500) + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + t.equal(lookupCounter, 2) +}) + +test('Should we handle TTL (6)', async t => { + t = tspl(t, { plan: 7 }) + + let counter = 0 + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0, '::1') + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ + dualStack: false, + affinity: 6, + maxTTL: 100, + lookup: (origin, opts, cb) => { + ++lookupCounter + lookup( + origin.hostname, + { all: true, family: opts.affinity }, + (err, addresses) => { + if (err) { + return cb(err) + } + + const results = [] + + for (const addr of addresses) { + const record = { + address: addr.address, + ttl: opts.maxTTL, + family: addr.family + } + + results.push(record) + } + + cb(null, results) + } + ) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + await sleep(200) + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + t.equal(lookupCounter, 2) +}) + +test('Should handle max cached items', async t => { + t = tspl(t, { plan: 9 }) + + let counter = 0 + const server1 = createServer() + const server2 = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server1.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server1.listen(0) + + server2.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world! (x2)') + }) + server2.listen(0) + + await Promise.all([once(server1, 'listening'), once(server2, 'listening')]) + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 3: + t.equal(url.hostname, 'developer.mozilla.org') + // Rewrite origin to avoid reaching internet + opts.origin = `http://127.0.0.1:${server2.address().port}` + break + default: + t.fails('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ maxItems: 1 }) + ]) + + after(async () => { + await client.close() + server1.close() + server2.close() + + await Promise.all([once(server1, 'close'), once(server2, 'close')]) + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server1.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server1.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + + const response3 = await client.request({ + ...requestOptions, + origin: 'https://developer.mozilla.org' + }) + + t.equal(response3.statusCode, 200) + t.equal(await response3.body.text(), 'hello world! (x2)') +}) diff --git a/types/interceptors.d.ts b/types/interceptors.d.ts index 53835e01299..6fc50fb8dc1 100644 --- a/types/interceptors.d.ts +++ b/types/interceptors.d.ts @@ -1,5 +1,6 @@ import Dispatcher from './dispatcher' import RetryHandler from './retry-handler' +import { LookupOptions } from 'node:dns' export default Interceptors @@ -7,11 +8,24 @@ declare namespace Interceptors { export type DumpInterceptorOpts = { maxSize?: number } export type RetryInterceptorOpts = RetryHandler.RetryOptions export type RedirectInterceptorOpts = { maxRedirections?: number } + export type ResponseErrorInterceptorOpts = { throwOnError: boolean } + // DNS interceptor + export type DNSInterceptorRecord = { address: string, ttl: number, family: 4 | 6 } + export type DNSInterceptorOriginRecords = { 4: { ips: DNSInterceptorRecord[] } | null, 6: { ips: DNSInterceptorRecord[] } | null } + export type DNSInterceptorOpts = { + maxTTL?: number + maxItems?: number + lookup?: (hostname: string, options: LookupOptions, callback: (err: NodeJS.ErrnoException | null, addresses: DNSInterceptorRecord[]) => void) => void + pick?: (origin: URL, records: DNSInterceptorOriginRecords, affinity: 4 | 6) => DNSInterceptorRecord + dualStack?: boolean + affinity?: 4 | 6 + } export function createRedirectInterceptor (opts: RedirectInterceptorOpts): Dispatcher.DispatcherComposeInterceptor export function dump (opts?: DumpInterceptorOpts): Dispatcher.DispatcherComposeInterceptor export function retry (opts?: RetryInterceptorOpts): Dispatcher.DispatcherComposeInterceptor export function redirect (opts?: RedirectInterceptorOpts): Dispatcher.DispatcherComposeInterceptor export function responseError (opts?: ResponseErrorInterceptorOpts): Dispatcher.DispatcherComposeInterceptor + export function dns (opts?: DNSInterceptorOpts): Dispatcher.DispatcherComposeInterceptor }