diff --git a/lib/dispatcher/client-h2.js b/lib/dispatcher/client-h2.js index ba9157cab5b..fffc0f84ea1 100644 --- a/lib/dispatcher/client-h2.js +++ b/lib/dispatcher/client-h2.js @@ -929,7 +929,6 @@ function writeH2 (client, request) { stream[kHTTP2Stream] = true stream[kRequestStreamState] = state state.stream = stream - bindRequestToStream(request, stream, null) // Increment counter as we have new streams open ++session[kOpenStreams] diff --git a/test/http2-dispatcher.js b/test/http2-dispatcher.js index bbdd2e112a3..262f7001afe 100644 --- a/test/http2-dispatcher.js +++ b/test/http2-dispatcher.js @@ -673,6 +673,105 @@ test('Should clear h2 request stream references before completing a response', a await t.completed }) +test('Should clear h2 request stream references after abort before response', async t => { + t = tspl(t, { plan: 5 }) + + const server = createSecureServer(await pem.generate({ opts: { keySize: 2048 } })) + + server.on('stream', (stream) => { + stream.on('error', () => {}) + }) + + after(() => server.close()) + await once(server.listen(0), 'listening') + + const client = new Client(`https://localhost:${server.address().port}`, { + connect: { + rejectUnauthorized: false + }, + allowH2: true + }) + after(() => client.close()) + + const abortReason = new Error('abort after h2 stream assigned') + let abortRequest = null + + const waitFor = async (fn) => { + for (let i = 0; i < 100; i++) { + const value = fn() + if (value != null) { + return value + } + await sleep(10) + } + throw new Error('timed out waiting for h2 request stream') + } + + const responseError = new Promise(resolve => { + client.dispatch({ + path: '/', + method: 'GET' + }, { + onRequestStart (controller) { + abortRequest = controller.abort.bind(controller) + }, + onResponseStart () { + t.fail('unexpected response') + }, + onResponseData () { + return true + }, + onResponseEnd () { + t.fail('unexpected response end') + }, + onResponseError (_controller, err) { + resolve(err) + } + }) + }) + + const { + request, + requestStreamIdSymbol, + requestStreamSymbol, + requestStreamCleanupSymbol + } = await waitFor(() => { + const request = client[kQueue][client[kRunningIdx]] + if (request == null) { + return null + } + + const symbols = Object.getOwnPropertySymbols(request) + const requestStreamIdSymbol = symbols.find((symbol) => symbol.description === 'request stream id') + const requestStreamSymbol = symbols.find((symbol) => symbol.description === 'request stream') + const requestStreamCleanupSymbol = symbols.find((symbol) => symbol.description === 'request stream cleanup') + + if (requestStreamIdSymbol == null || requestStreamSymbol == null || requestStreamCleanupSymbol == null) { + return null + } + + return { + request, + requestStreamIdSymbol, + requestStreamSymbol, + requestStreamCleanupSymbol + } + }) + + t.ok(await waitFor(() => request[requestStreamSymbol])) + + abortRequest = await waitFor(() => abortRequest) + abortRequest(abortReason) + + const err = await responseError + t.strictEqual(err, abortReason) + t.strictEqual(request[requestStreamIdSymbol], null) + t.strictEqual(request[requestStreamSymbol], null) + t.strictEqual(request[requestStreamCleanupSymbol], null) + + await t.completed +}) + test('Should only accept valid ping interval values', async t => { const planner = tspl(t, { plan: 3 })