diff --git a/packages/toolkit/src/query/core/apiState.ts b/packages/toolkit/src/query/core/apiState.ts index 7e1c1420e0..91cccf9ca0 100644 --- a/packages/toolkit/src/query/core/apiState.ts +++ b/packages/toolkit/src/query/core/apiState.ts @@ -148,6 +148,7 @@ export type SubscriptionOptions = { */ refetchOnFocus?: boolean } +export type SubscribersInternal = Map export type Subscribers = { [requestId: string]: SubscriptionOptions } export type QueryKeys = { [K in keyof Definitions]: Definitions[K] extends QueryDefinition< @@ -327,6 +328,8 @@ export type QueryState = { | undefined } +export type SubscriptionInternalState = Map + export type SubscriptionState = { [queryCacheKey: string]: Subscribers | undefined } diff --git a/packages/toolkit/src/query/core/buildInitiate.ts b/packages/toolkit/src/query/core/buildInitiate.ts index f5af4d8667..54ba080373 100644 --- a/packages/toolkit/src/query/core/buildInitiate.ts +++ b/packages/toolkit/src/query/core/buildInitiate.ts @@ -42,6 +42,7 @@ import type { ThunkApiMetaConfig, } from './buildThunks' import type { ApiEndpointQuery } from './module' +import type { InternalMiddlewareState } from './buildMiddleware/types' export type BuildInitiateApiEndpointQuery< Definition extends QueryDefinition, @@ -270,6 +271,7 @@ export function buildInitiate({ mutationThunk, api, context, + internalState, }: { serializeQueryArgs: InternalSerializeQueryArgs queryThunk: QueryThunk @@ -277,20 +279,9 @@ export function buildInitiate({ mutationThunk: MutationThunk api: Api context: ApiContext + internalState: InternalMiddlewareState }) { - const runningQueries: Map< - Dispatch, - Record< - string, - | QueryActionCreatorResult - | InfiniteQueryActionCreatorResult - | undefined - > - > = new Map() - const runningMutations: Map< - Dispatch, - Record | undefined> - > = new Map() + const { runningQueries, runningMutations } = internalState const { unsubscribeQueryResult, diff --git a/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts b/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts index b847dee01a..bffbc32ebe 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts @@ -1,12 +1,12 @@ import type { InternalHandlerBuilder, SubscriptionSelectors } from './types' -import type { SubscriptionState } from '../apiState' +import type { SubscriptionInternalState, SubscriptionState } from '../apiState' import { produceWithPatches } from 'immer' import type { Action } from '@reduxjs/toolkit' -import { countObjectKeys } from '../../utils/countObjectKeys' +import { getOrInsertComputed, createNewMap } from '../../utils/getOrInsert' export const buildBatchedActionsHandler: InternalHandlerBuilder< [actionShouldContinue: boolean, returnValue: SubscriptionSelectors | boolean] -> = ({ api, queryThunk, internalState }) => { +> = ({ api, queryThunk, internalState, mwApi }) => { const subscriptionsPrefix = `${api.reducerPath}/subscriptions` let previousSubscriptions: SubscriptionState = @@ -20,58 +20,63 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< // Actually intentionally mutate the subscriptions state used in the middleware // This is done to speed up perf when loading many components const actuallyMutateSubscriptions = ( - mutableState: SubscriptionState, + currentSubscriptions: SubscriptionInternalState, action: Action, ) => { if (updateSubscriptionOptions.match(action)) { const { queryCacheKey, requestId, options } = action.payload - if (mutableState?.[queryCacheKey]?.[requestId]) { - mutableState[queryCacheKey]![requestId] = options + const sub = currentSubscriptions.get(queryCacheKey) + if (sub?.has(requestId)) { + sub.set(requestId, options) } return true } if (unsubscribeQueryResult.match(action)) { const { queryCacheKey, requestId } = action.payload - if (mutableState[queryCacheKey]) { - delete mutableState[queryCacheKey]![requestId] + const sub = currentSubscriptions.get(queryCacheKey) + if (sub) { + sub.delete(requestId) } return true } if (api.internalActions.removeQueryResult.match(action)) { - delete mutableState[action.payload.queryCacheKey] + currentSubscriptions.delete(action.payload.queryCacheKey) return true } if (queryThunk.pending.match(action)) { const { meta: { arg, requestId }, } = action - const substate = (mutableState[arg.queryCacheKey] ??= {}) - substate[`${requestId}_running`] = {} + const substate = getOrInsertComputed( + currentSubscriptions, + arg.queryCacheKey, + createNewMap, + ) if (arg.subscribe) { - substate[requestId] = - arg.subscriptionOptions ?? substate[requestId] ?? {} + substate.set( + requestId, + arg.subscriptionOptions ?? substate.get(requestId) ?? {}, + ) } return true } let mutated = false - if ( - queryThunk.fulfilled.match(action) || - queryThunk.rejected.match(action) - ) { - const state = mutableState[action.meta.arg.queryCacheKey] || {} - const key = `${action.meta.requestId}_running` - mutated ||= !!state[key] - delete state[key] - } + if (queryThunk.rejected.match(action)) { const { meta: { condition, arg, requestId }, } = action if (condition && arg.subscribe) { - const substate = (mutableState[arg.queryCacheKey] ??= {}) - substate[requestId] = - arg.subscriptionOptions ?? substate[requestId] ?? {} + const substate = getOrInsertComputed( + currentSubscriptions, + arg.queryCacheKey, + createNewMap, + ) + substate.set( + requestId, + arg.subscriptionOptions ?? substate.get(requestId) ?? {}, + ) mutated = true } @@ -83,12 +88,12 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< const getSubscriptions = () => internalState.currentSubscriptions const getSubscriptionCount = (queryCacheKey: string) => { const subscriptions = getSubscriptions() - const subscriptionsForQueryArg = subscriptions[queryCacheKey] ?? {} - return countObjectKeys(subscriptionsForQueryArg) + const subscriptionsForQueryArg = subscriptions.get(queryCacheKey) + return subscriptionsForQueryArg?.size ?? 0 } const isRequestSubscribed = (queryCacheKey: string, requestId: string) => { const subscriptions = getSubscriptions() - return !!subscriptions?.[queryCacheKey]?.[requestId] + return !!subscriptions?.get(queryCacheKey)?.get(requestId) } const subscriptionSelectors: SubscriptionSelectors = { @@ -97,6 +102,21 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< isRequestSubscribed, } + function serializeSubscriptions( + currentSubscriptions: SubscriptionInternalState, + ): SubscriptionState { + // We now use nested Maps for subscriptions, instead of + // plain Records. Stringify this accordingly so we can + // convert it to the shape we need for the store. + return JSON.parse( + JSON.stringify( + Object.fromEntries( + [...currentSubscriptions].map(([k, v]) => [k, Object.fromEntries(v)]), + ), + ), + ) + } + return ( action, mwApi, @@ -106,13 +126,14 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< ] => { if (!previousSubscriptions) { // Initialize it the first time this handler runs - previousSubscriptions = JSON.parse( - JSON.stringify(internalState.currentSubscriptions), + previousSubscriptions = serializeSubscriptions( + internalState.currentSubscriptions, ) } if (api.util.resetApiState.match(action)) { - previousSubscriptions = internalState.currentSubscriptions = {} + previousSubscriptions = {} + internalState.currentSubscriptions.clear() updateSyncTimer = null return [true, false] } @@ -133,6 +154,15 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< let actionShouldContinue = true + // HACK Sneak the test-only polling state back out + if ( + process.env.NODE_ENV === 'test' && + typeof action.type === 'string' && + action.type === `${api.reducerPath}/getPolling` + ) { + return [false, internalState.currentPolls] as any + } + if (didMutate) { if (!updateSyncTimer) { // We only use the subscription state for the Redux DevTools at this point, @@ -142,8 +172,8 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< // In 1.9, it was updated in a microtask, but now we do it at most every 500ms. updateSyncTimer = setTimeout(() => { // Deep clone the current subscription data - const newSubscriptions: SubscriptionState = JSON.parse( - JSON.stringify(internalState.currentSubscriptions), + const newSubscriptions: SubscriptionState = serializeSubscriptions( + internalState.currentSubscriptions, ) // Figure out a smaller diff between original and current const [, patches] = produceWithPatches( diff --git a/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts b/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts index 050b70dba1..fbe70d0de6 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts @@ -35,10 +35,13 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ internalState, selectors: { selectQueryEntry, selectConfig }, getRunningQueryThunk, + mwApi, }) => { const { removeQueryResult, unsubscribeQueryResult, cacheEntriesUpserted } = api.internalActions + const runningQueries = internalState.runningQueries.get(mwApi.dispatch)! + const canTriggerUnsubscribe = isAnyOf( unsubscribeQueryResult.match, queryThunk.fulfilled, @@ -47,19 +50,14 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ ) function anySubscriptionsRemainingForKey(queryCacheKey: string) { - const subscriptions = internalState.currentSubscriptions[queryCacheKey] + const subscriptions = internalState.currentSubscriptions.get(queryCacheKey) if (!subscriptions) { return false } - // Check if there are any keys that are NOT _running subscriptions - for (const key in subscriptions) { - if (!key.endsWith('_running')) { - return true - } - } - // Only _running subscriptions remain (or empty) - return false + const hasSubscriptions = subscriptions.size > 0 + const isRunning = runningQueries?.[queryCacheKey] !== undefined + return hasSubscriptions || isRunning } const currentRemovalTimeouts: QueryStateMeta = {} diff --git a/packages/toolkit/src/query/core/buildMiddleware/index.ts b/packages/toolkit/src/query/core/buildMiddleware/index.ts index 1f55b5ef22..4a24e1ec88 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/index.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/index.ts @@ -45,7 +45,7 @@ export function buildMiddleware< ReducerPath extends string, TagTypes extends string, >(input: BuildMiddlewareInput) { - const { reducerPath, queryThunk, api, context } = input + const { reducerPath, queryThunk, api, context, internalState } = input const { apiUid } = context const actions = { @@ -73,10 +73,6 @@ export function buildMiddleware< > = (mwApi) => { let initialized = false - const internalState: InternalMiddlewareState = { - currentSubscriptions: {}, - } - const builderArgs = { ...(input as any as BuildMiddlewareInput< EndpointDefinitions, @@ -86,6 +82,7 @@ export function buildMiddleware< internalState, refetchQuery, isThisApiSliceAction, + mwApi, } const handlers = handlerBuilders.map((build) => build(builderArgs)) diff --git a/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts b/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts index 78d90aa725..ae50030894 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts @@ -17,9 +17,8 @@ import type { SubMiddlewareApi, InternalHandlerBuilder, ApiMiddlewareInternalHandler, - InternalMiddlewareState, } from './types' -import { countObjectKeys } from '../../utils/countObjectKeys' +import { getOrInsertComputed, createNewMap } from '../../utils/getOrInsert' export const buildInvalidationByTagsHandler: InternalHandlerBuilder = ({ reducerPath, @@ -111,11 +110,14 @@ export const buildInvalidationByTagsHandler: InternalHandlerBuilder = ({ const valuesArray = Array.from(toInvalidate.values()) for (const { queryCacheKey } of valuesArray) { const querySubState = state.queries[queryCacheKey] - const subscriptionSubState = - internalState.currentSubscriptions[queryCacheKey] ?? {} + const subscriptionSubState = getOrInsertComputed( + internalState.currentSubscriptions, + queryCacheKey, + createNewMap, + ) if (querySubState) { - if (countObjectKeys(subscriptionSubState) === 0) { + if (subscriptionSubState.size === 0) { mwApi.dispatch( removeQueryResult({ queryCacheKey: queryCacheKey as QueryCacheKey, diff --git a/packages/toolkit/src/query/core/buildMiddleware/polling.ts b/packages/toolkit/src/query/core/buildMiddleware/polling.ts index 94b6258842..70f7b177d0 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/polling.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/polling.ts @@ -2,6 +2,7 @@ import type { QueryCacheKey, QuerySubstateIdentifier, Subscribers, + SubscribersInternal, } from '../apiState' import { QueryStatus } from '../apiState' import type { @@ -20,25 +21,25 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ refetchQuery, internalState, }) => { - const currentPolls: QueryStateMeta<{ - nextPollTimestamp: number - timeout?: TimeoutId - pollingInterval: number - }> = {} + const { currentPolls, currentSubscriptions } = internalState + + // Batching state for polling updates + const pendingPollingUpdates = new Set() + let pollingUpdateTimer: ReturnType | null = null const handler: ApiMiddlewareInternalHandler = (action, mwApi) => { if ( api.internalActions.updateSubscriptionOptions.match(action) || api.internalActions.unsubscribeQueryResult.match(action) ) { - updatePollingInterval(action.payload, mwApi) + schedulePollingUpdate(action.payload.queryCacheKey, mwApi) } if ( queryThunk.pending.match(action) || (queryThunk.rejected.match(action) && action.meta.condition) ) { - updatePollingInterval(action.meta.arg, mwApi) + schedulePollingUpdate(action.meta.arg.queryCacheKey, mwApi) } if ( @@ -50,6 +51,27 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ if (api.util.resetApiState.match(action)) { clearPolls() + // Clear any pending updates + if (pollingUpdateTimer) { + clearTimeout(pollingUpdateTimer) + pollingUpdateTimer = null + } + pendingPollingUpdates.clear() + } + } + + function schedulePollingUpdate(queryCacheKey: string, api: SubMiddlewareApi) { + pendingPollingUpdates.add(queryCacheKey) + + if (!pollingUpdateTimer) { + pollingUpdateTimer = setTimeout(() => { + // Process all pending updates in a single batch + for (const key of pendingPollingUpdates) { + updatePollingInterval({ queryCacheKey: key as any }, api) + } + pendingPollingUpdates.clear() + pollingUpdateTimer = null + }, 0) } } @@ -59,7 +81,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ ) { const state = api.getState()[reducerPath] const querySubState = state.queries[queryCacheKey] - const subscriptions = internalState.currentSubscriptions[queryCacheKey] + const subscriptions = currentSubscriptions.get(queryCacheKey) if (!querySubState || querySubState.status === QueryStatus.uninitialized) return @@ -73,7 +95,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ ) { const state = api.getState()[reducerPath] const querySubState = state.queries[queryCacheKey] - const subscriptions = internalState.currentSubscriptions[queryCacheKey] + const subscriptions = currentSubscriptions.get(queryCacheKey) if (!querySubState || querySubState.status === QueryStatus.uninitialized) return @@ -82,7 +104,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ findLowestPollingInterval(subscriptions) if (!Number.isFinite(lowestPollingInterval)) return - const currentPoll = currentPolls[queryCacheKey] + const currentPoll = currentPolls.get(queryCacheKey) if (currentPoll?.timeout) { clearTimeout(currentPoll.timeout) @@ -91,7 +113,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ const nextPollTimestamp = Date.now() + lowestPollingInterval - currentPolls[queryCacheKey] = { + currentPolls.set(queryCacheKey, { nextPollTimestamp, pollingInterval: lowestPollingInterval, timeout: setTimeout(() => { @@ -100,7 +122,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ } startNextPoll({ queryCacheKey }, api) }, lowestPollingInterval), - } + }) } function updatePollingInterval( @@ -109,7 +131,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ ) { const state = api.getState()[reducerPath] const querySubState = state.queries[queryCacheKey] - const subscriptions = internalState.currentSubscriptions[queryCacheKey] + const subscriptions = currentSubscriptions.get(queryCacheKey) if (!querySubState || querySubState.status === QueryStatus.uninitialized) { return @@ -117,12 +139,21 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ const { lowestPollingInterval } = findLowestPollingInterval(subscriptions) + // HACK add extra data to track how many times this has been called in tests + // yes we're mutating a nonexistent field on a Map here + if (process.env.NODE_ENV === 'test') { + const updateCounters = ((currentPolls as any).pollUpdateCounters ??= {}) + updateCounters[queryCacheKey] ??= 0 + updateCounters[queryCacheKey]++ + } + if (!Number.isFinite(lowestPollingInterval)) { cleanupPollForKey(queryCacheKey) return } - const currentPoll = currentPolls[queryCacheKey] + const currentPoll = currentPolls.get(queryCacheKey) + const nextPollTimestamp = Date.now() + lowestPollingInterval if (!currentPoll || nextPollTimestamp < currentPoll.nextPollTimestamp) { @@ -131,30 +162,33 @@ export const buildPollingHandler: InternalHandlerBuilder = ({ } function cleanupPollForKey(key: string) { - const existingPoll = currentPolls[key] + const existingPoll = currentPolls.get(key) if (existingPoll?.timeout) { clearTimeout(existingPoll.timeout) } - delete currentPolls[key] + currentPolls.delete(key) } function clearPolls() { - for (const key of Object.keys(currentPolls)) { + for (const key of currentPolls.keys()) { cleanupPollForKey(key) } } - function findLowestPollingInterval(subscribers: Subscribers = {}) { + function findLowestPollingInterval( + subscribers: SubscribersInternal = new Map(), + ) { let skipPollingIfUnfocused: boolean | undefined = false let lowestPollingInterval = Number.POSITIVE_INFINITY - for (let key in subscribers) { - if (!!subscribers[key].pollingInterval) { + + for (const entry of subscribers.values()) { + if (!!entry.pollingInterval) { lowestPollingInterval = Math.min( - subscribers[key].pollingInterval!, + entry.pollingInterval!, lowestPollingInterval, ) skipPollingIfUnfocused = - subscribers[key].skipPollingIfUnfocused || skipPollingIfUnfocused + entry.skipPollingIfUnfocused || skipPollingIfUnfocused } } diff --git a/packages/toolkit/src/query/core/buildMiddleware/types.ts b/packages/toolkit/src/query/core/buildMiddleware/types.ts index 2db9cb775d..e3ffe3340e 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/types.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/types.ts @@ -17,6 +17,7 @@ import type { QueryStatus, QuerySubState, RootState, + SubscriptionInternalState, SubscriptionState, } from '../apiState' import type { @@ -26,18 +27,42 @@ import type { QueryThunkArg, ThunkResult, } from '../buildThunks' -import type { QueryActionCreatorResult } from '../buildInitiate' +import type { + InfiniteQueryActionCreatorResult, + MutationActionCreatorResult, + QueryActionCreatorResult, +} from '../buildInitiate' import type { AllSelectors } from '../buildSelectors' export type QueryStateMeta = Record export type TimeoutId = ReturnType +type QueryPollState = { + nextPollTimestamp: number + timeout?: TimeoutId + pollingInterval: number +} + export interface InternalMiddlewareState { - currentSubscriptions: SubscriptionState + currentSubscriptions: SubscriptionInternalState + currentPolls: Map + runningQueries: Map< + Dispatch, + Record< + string, + | QueryActionCreatorResult + | InfiniteQueryActionCreatorResult + | undefined + > + > + runningMutations: Map< + Dispatch, + Record | undefined> + > } export interface SubscriptionSelectors { - getSubscriptions: () => SubscriptionState + getSubscriptions: () => SubscriptionInternalState getSubscriptionCount: (queryCacheKey: string) => number isRequestSubscribed: (queryCacheKey: string, requestId: string) => boolean } @@ -59,6 +84,7 @@ export interface BuildMiddlewareInput< endpointName: string, queryArgs: any, ) => (dispatch: Dispatch) => QueryActionCreatorResult | undefined + internalState: InternalMiddlewareState } export type SubMiddlewareApi = MiddlewareAPI< @@ -77,6 +103,10 @@ export interface BuildSubMiddlewareInput ): ThunkAction, any, any, UnknownAction> isThisApiSliceAction: (action: Action) => boolean selectors: AllSelectors + mwApi: MiddlewareAPI< + ThunkDispatch, + RootState + > } export type SubMiddlewareBuilder = ( diff --git a/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts b/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts index e3c17b6513..53b5c87230 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts @@ -35,23 +35,19 @@ export const buildWindowEventHandler: InternalHandlerBuilder = ({ const subscriptions = internalState.currentSubscriptions context.batch(() => { - for (const queryCacheKey of Object.keys(subscriptions)) { + for (const queryCacheKey of subscriptions.keys()) { const querySubState = queries[queryCacheKey] - const subscriptionSubState = subscriptions[queryCacheKey] + const subscriptionSubState = subscriptions.get(queryCacheKey) if (!subscriptionSubState || !querySubState) continue + const values = [...subscriptionSubState.values()] const shouldRefetch = - Object.values(subscriptionSubState).some( - (sub) => sub[type] === true, - ) || - (Object.values(subscriptionSubState).every( - (sub) => sub[type] === undefined, - ) && - state.config[type]) + values.some((sub) => sub[type] === true) || + (values.every((sub) => sub[type] === undefined) && state.config[type]) if (shouldRefetch) { - if (countObjectKeys(subscriptionSubState) === 0) { + if (subscriptionSubState.size === 0) { api.dispatch( removeQueryResult({ queryCacheKey: queryCacheKey as QueryCacheKey, diff --git a/packages/toolkit/src/query/core/module.ts b/packages/toolkit/src/query/core/module.ts index 4133203e13..e9bfa08b76 100644 --- a/packages/toolkit/src/query/core/module.ts +++ b/packages/toolkit/src/query/core/module.ts @@ -71,6 +71,7 @@ import type { import { buildThunks } from './buildThunks' import { createSelector as _createSelector } from './rtkImports' import { onFocus, onFocusLost, onOffline, onOnline } from './setupListeners' +import type { InternalMiddlewareState } from './buildMiddleware/types' /** * `ifOlderThan` - (default: `false` | `number`) - _number is value in seconds_ @@ -618,6 +619,13 @@ export const coreModule = ({ }) safeAssign(api.internalActions, sliceActions) + const internalState: InternalMiddlewareState = { + currentSubscriptions: new Map(), + currentPolls: new Map(), + runningQueries: new Map(), + runningMutations: new Map(), + } + const { buildInitiateQuery, buildInitiateInfiniteQuery, @@ -633,6 +641,7 @@ export const coreModule = ({ api, serializeQueryArgs: serializeQueryArgs as any, context, + internalState, }) safeAssign(api.util, { @@ -652,6 +661,7 @@ export const coreModule = ({ assertTagType, selectors, getRunningQueryThunk, + internalState, }) safeAssign(api.util, middlewareActions) diff --git a/packages/toolkit/src/query/tests/buildHooks.test.tsx b/packages/toolkit/src/query/tests/buildHooks.test.tsx index 0a72ce8a58..115a475a9f 100644 --- a/packages/toolkit/src/query/tests/buildHooks.test.tsx +++ b/packages/toolkit/src/query/tests/buildHooks.test.tsx @@ -1068,7 +1068,7 @@ describe('hooks tests', () => { const checkNumSubscriptions = (arg: string, count: number) => { const subscriptions = getSubscriptions() - const cacheKeyEntry = subscriptions[arg] + const cacheKeyEntry = subscriptions.get(arg) if (cacheKeyEntry) { const subscriptionCount = Object.keys(cacheKeyEntry) //getSubscriptionCount(arg) @@ -3780,7 +3780,7 @@ describe('skip behavior', () => { expect(getSubscriptionCount('getUser(1)')).toBe(0) // also no subscription on `getUser(skipToken)` or similar: - expect(getSubscriptions()).toEqual({}) + expect(getSubscriptions().size).toBe(0) rerender([1]) @@ -3791,7 +3791,7 @@ describe('skip behavior', () => { expect(result.current).toMatchObject({ status: QueryStatus.fulfilled }) await waitMs(1) expect(getSubscriptionCount('getUser(1)')).toBe(1) - expect(getSubscriptions()).not.toEqual({}) + expect(getSubscriptions().size).toBe(1) rerender([skipToken]) @@ -3821,7 +3821,7 @@ describe('skip behavior', () => { expect(getSubscriptionCount('nestedValue')).toBe(0) // also no subscription on `getUser(skipToken)` or similar: - expect(getSubscriptions()).toEqual({}) + expect(getSubscriptions().size).toBe(0) rerender([{ param: { nested: 'nestedValue' } }]) @@ -3833,7 +3833,7 @@ describe('skip behavior', () => { await waitMs(1) expect(getSubscriptionCount('nestedValue')).toBe(1) - expect(getSubscriptions()).not.toEqual({}) + expect(getSubscriptions().size).toBe(1) rerender([skipToken]) diff --git a/packages/toolkit/src/query/tests/buildInitiate.test.tsx b/packages/toolkit/src/query/tests/buildInitiate.test.tsx index bbced70b49..74d1d61e51 100644 --- a/packages/toolkit/src/query/tests/buildInitiate.test.tsx +++ b/packages/toolkit/src/query/tests/buildInitiate.test.tsx @@ -79,23 +79,11 @@ describe('calling initiate without a cache entry, with subscribe: false still re expect(isRequestSubscribed('increment(undefined)', promise.requestId)).toBe( false, ) - expect( - isRequestSubscribed( - 'increment(undefined)', - `${promise.requestId}_running`, - ), - ).toBe(true) await expect(promise).resolves.toMatchObject({ data: 0, status: 'fulfilled', }) - expect( - isRequestSubscribed( - 'increment(undefined)', - `${promise.requestId}_running`, - ), - ).toBe(false) }) test('rejected query', async () => { @@ -107,16 +95,10 @@ describe('calling initiate without a cache entry, with subscribe: false still re expect(isRequestSubscribed('failing(undefined)', promise.requestId)).toBe( false, ) - expect( - isRequestSubscribed('failing(undefined)', `${promise.requestId}_running`), - ).toBe(true) await expect(promise).resolves.toMatchObject({ status: 'rejected', }) - expect( - isRequestSubscribed('failing(undefined)', `${promise.requestId}_running`), - ).toBe(false) }) }) diff --git a/packages/toolkit/src/query/tests/buildMiddleware.test.tsx b/packages/toolkit/src/query/tests/buildMiddleware.test.tsx index 0a03396302..2abac64498 100644 --- a/packages/toolkit/src/query/tests/buildMiddleware.test.tsx +++ b/packages/toolkit/src/query/tests/buildMiddleware.test.tsx @@ -26,11 +26,13 @@ const api = createApi({ providesTags: ['Bread'], }), invalidateFruit: build.mutation({ - query: (fruit?: 'Banana' | 'Bread' | null) => ({ url: `invalidate/fruit/${fruit || ''}` }), + query: (fruit?: 'Banana' | 'Bread' | null) => ({ + url: `invalidate/fruit/${fruit || ''}`, + }), invalidatesTags(result, error, arg) { return [arg] - } - }) + }, + }), }), }) const { getBanana, getBread, invalidateFruit } = api.endpoints @@ -77,9 +79,11 @@ it('invalidates the specified tags', async () => { ) }) -it('invalidates tags correctly when null or undefined are provided as tags', async() =>{ +it('invalidates tags correctly when null or undefined are provided as tags', async () => { await storeRef.store.dispatch(getBanana.initiate(1)) - await storeRef.store.dispatch(api.util.invalidateTags([undefined, null, 'Banana'])) + await storeRef.store.dispatch( + api.util.invalidateTags([undefined, null, 'Banana']), + ) // Slight pause to let the middleware run and such await delay(20) @@ -96,41 +100,116 @@ it('invalidates tags correctly when null or undefined are provided as tags', asy expect(storeRef.store.getState().actions).toMatchSequence(...apiActions) }) - it.each([ - { tags: [undefined, null, 'Bread'] as Parameters['0'] }, - { tags: [undefined, null], }, { tags: [] }] -)('does not invalidate with tags=$tags if no query matches', async ({ tags }) => { - await storeRef.store.dispatch(getBanana.initiate(1)) - await storeRef.store.dispatch(api.util.invalidateTags(tags)) + { + tags: [undefined, null, 'Bread'] as Parameters< + typeof api.util.invalidateTags + >['0'], + }, + { tags: [undefined, null] }, + { tags: [] }, +])( + 'does not invalidate with tags=$tags if no query matches', + async ({ tags }) => { + await storeRef.store.dispatch(getBanana.initiate(1)) + await storeRef.store.dispatch(api.util.invalidateTags(tags)) + + // Slight pause to let the middleware run and such + await delay(20) + + const apiActions = [ + api.internalActions.middlewareRegistered.match, + getBanana.matchPending, + getBanana.matchFulfilled, + api.util.invalidateTags.match, + ] + + expect(storeRef.store.getState().actions).toMatchSequence(...apiActions) + }, +) - // Slight pause to let the middleware run and such - await delay(20) +it.each([ + { mutationArg: 'Bread' as 'Bread' | null | undefined }, + { mutationArg: undefined }, + { mutationArg: null }, +])( + 'does not invalidate queries when a mutation with tags=[$mutationArg] runs and does not match anything', + async ({ mutationArg }) => { + await storeRef.store.dispatch(getBanana.initiate(1)) + await storeRef.store.dispatch(invalidateFruit.initiate(mutationArg)) + + // Slight pause to let the middleware run and such + await delay(20) + + const apiActions = [ + api.internalActions.middlewareRegistered.match, + getBanana.matchPending, + getBanana.matchFulfilled, + invalidateFruit.matchPending, + invalidateFruit.matchFulfilled, + ] + + expect(storeRef.store.getState().actions).toMatchSequence(...apiActions) + }, +) + +it('correctly stringifies subscription state and dispatches subscriptionsUpdated', async () => { + // Create a fresh store for this test to avoid interference + const testStoreRef = setupApiStore( + api, + { + ...actionsReducer, + }, + { withoutListeners: true }, + ) - const apiActions = [ - api.internalActions.middlewareRegistered.match, - getBanana.matchPending, - getBanana.matchFulfilled, - api.util.invalidateTags.match, - ] + // Start multiple subscriptions + const subscription1 = testStoreRef.store.dispatch( + getBanana.initiate(1, { + subscriptionOptions: { pollingInterval: 1000 }, + }), + ) + const subscription2 = testStoreRef.store.dispatch( + getBanana.initiate(2, { + subscriptionOptions: { refetchOnFocus: true }, + }), + ) + const subscription3 = testStoreRef.store.dispatch( + api.endpoints.getBananas.initiate(), + ) - expect(storeRef.store.getState().actions).toMatchSequence(...apiActions) + // Wait for the subscriptions to be established + await Promise.all([subscription1, subscription2, subscription3]) + + // Wait for the subscription sync timer (500ms + buffer) + await delay(600) + + // Check the final subscription state in the store + const finalState = testStoreRef.store.getState() + const subscriptionState = finalState[api.reducerPath].subscriptions + + // Should have subscriptions for getBanana(1), getBanana(2), and getBananas() + expect(subscriptionState).toMatchObject({ + 'getBanana(1)': { + [subscription1.requestId]: { pollingInterval: 1000 }, + }, + 'getBanana(2)': { + [subscription2.requestId]: { refetchOnFocus: true }, + }, + 'getBananas(undefined)': { + [subscription3.requestId]: {}, + }, + }) + + // Verify the subscription entries have the expected structure + expect(Object.keys(subscriptionState)).toHaveLength(3) + expect(subscriptionState['getBanana(1)']?.[subscription1.requestId]).toEqual({ + pollingInterval: 1000, + }) + expect(subscriptionState['getBanana(2)']?.[subscription2.requestId]).toEqual({ + refetchOnFocus: true, + }) + expect( + subscriptionState['getBananas(undefined)']?.[subscription3.requestId], + ).toEqual({}) }) - -it.each([{ mutationArg: 'Bread' as "Bread" | null | undefined }, { mutationArg: undefined }, { mutationArg: null }])('does not invalidate queries when a mutation with tags=[$mutationArg] runs and does not match anything', async ({ mutationArg }) => { - await storeRef.store.dispatch(getBanana.initiate(1)) - await storeRef.store.dispatch(invalidateFruit.initiate(mutationArg)) - - // Slight pause to let the middleware run and such - await delay(20) - - const apiActions = [ - api.internalActions.middlewareRegistered.match, - getBanana.matchPending, - getBanana.matchFulfilled, - invalidateFruit.matchPending, - invalidateFruit.matchFulfilled, - ] - - expect(storeRef.store.getState().actions).toMatchSequence(...apiActions) -}) \ No newline at end of file diff --git a/packages/toolkit/src/query/tests/infiniteQueries.test.ts b/packages/toolkit/src/query/tests/infiniteQueries.test.ts index 1325d135af..da5e10ae70 100644 --- a/packages/toolkit/src/query/tests/infiniteQueries.test.ts +++ b/packages/toolkit/src/query/tests/infiniteQueries.test.ts @@ -16,6 +16,7 @@ describe('Infinite queries', () => { name: string } + type HitCounter = { page: number; hitCounter: number } let counters: Record = {} let queryCounter = 0 @@ -88,39 +89,41 @@ describe('Infinite queries', () => { }), }) - let hitCounter = 0 - - type HitCounter = { page: number; hitCounter: number } + function createCountersApi() { + let hitCounter = 0 - const countersApi = createApi({ - baseQuery: fakeBaseQuery(), - tagTypes: ['Counter'], - endpoints: (build) => ({ - counters: build.infiniteQuery({ - queryFn({ pageParam }) { - hitCounter++ + const countersApi = createApi({ + baseQuery: fakeBaseQuery(), + tagTypes: ['Counter'], + endpoints: (build) => ({ + counters: build.infiniteQuery({ + queryFn({ pageParam }) { + hitCounter++ - return { data: { page: pageParam, hitCounter } } - }, - infiniteQueryOptions: { - initialPageParam: 0, - getNextPageParam: ( - lastPage, - allPages, - lastPageParam, - allPageParams, - ) => lastPageParam + 1, - }, - providesTags: ['Counter'], - }), - mutation: build.mutation({ - queryFn: async () => { - return { data: null } - }, - invalidatesTags: ['Counter'], + return { data: { page: pageParam, hitCounter } } + }, + infiniteQueryOptions: { + initialPageParam: 0, + getNextPageParam: ( + lastPage, + allPages, + lastPageParam, + allPageParams, + ) => lastPageParam + 1, + }, + providesTags: ['Counter'], + }), + mutation: build.mutation({ + queryFn: async () => { + return { data: null } + }, + invalidatesTags: ['Counter'], + }), }), - }), - }) + }) + + return countersApi + } let storeRef = setupApiStore( pokemonApi, @@ -155,7 +158,6 @@ describe('Infinite queries', () => { counters = {} - hitCounter = 0 queryCounter = 0 }) @@ -411,6 +413,8 @@ describe('Infinite queries', () => { } } + const countersApi = createCountersApi() + const storeRef = setupApiStore( countersApi, { ...actionsReducer }, @@ -465,6 +469,8 @@ describe('Infinite queries', () => { } } + const countersApi = createCountersApi() + const storeRef = setupApiStore( countersApi, { ...actionsReducer }, @@ -528,6 +534,7 @@ describe('Infinite queries', () => { }) test('Refetches on polling', async () => { + const countersApi = createCountersApi() const checkResultData = ( result: InfiniteQueryResult, expectedValues: HitCounter[], diff --git a/packages/toolkit/src/query/tests/polling.test.tsx b/packages/toolkit/src/query/tests/polling.test.tsx index 425a0bf804..cff9ab7449 100644 --- a/packages/toolkit/src/query/tests/polling.test.tsx +++ b/packages/toolkit/src/query/tests/polling.test.tsx @@ -1,4 +1,5 @@ import { createApi } from '@reduxjs/toolkit/query' +import type { QueryActionCreatorResult } from '@reduxjs/toolkit/query' import { delay } from 'msw' import { setupApiStore } from '../../tests/utils/helpers' import type { SubscriptionSelectors } from '../core/buildMiddleware/types' @@ -29,10 +30,15 @@ beforeEach(() => { ;({ getSubscriptions } = storeRef.store.dispatch( api.internalActions.internal_getRTKQSubscriptions(), ) as unknown as SubscriptionSelectors) + + const currentPolls = storeRef.store.dispatch({ + type: `${api.reducerPath}/getPolling`, + }) as any + ;(currentPolls as any).pollUpdateCounters = {} }) const getSubscribersForQueryCacheKey = (queryCacheKey: string) => - getSubscriptions()[queryCacheKey] || {} + getSubscriptions().get(queryCacheKey) ?? new Map() const createSubscriptionGetter = (queryCacheKey: string) => () => getSubscribersForQueryCacheKey(queryCacheKey) @@ -66,14 +72,14 @@ describe('polling tests', () => { const getSubs = createSubscriptionGetter(queryCacheKey) await delay(1) - expect(Object.keys(getSubs())).toHaveLength(1) - expect(getSubs()[requestId].pollingInterval).toBe(10) + expect(getSubs().size).toBe(1) + expect(getSubs()?.get(requestId)?.pollingInterval).toBe(10) subscription.updateSubscriptionOptions({ pollingInterval: 20 }) await delay(1) - expect(Object.keys(getSubs())).toHaveLength(1) - expect(getSubs()[requestId].pollingInterval).toBe(20) + expect(getSubs().size).toBe(1) + expect(getSubs()?.get(requestId)?.pollingInterval).toBe(20) }) it(`doesn't replace the interval when removing a shared query instance with a poll `, async () => { @@ -95,12 +101,12 @@ describe('polling tests', () => { const getSubs = createSubscriptionGetter(subscriptionOne.queryCacheKey) - expect(Object.keys(getSubs())).toHaveLength(2) + expect(getSubs().size).toBe(2) subscriptionOne.unsubscribe() await delay(1) - expect(Object.keys(getSubs())).toHaveLength(1) + expect(getSubs().size).toBe(1) }) it('uses lowest specified interval when two components are mounted', async () => { @@ -155,7 +161,7 @@ describe('polling tests', () => { const callsWithoutSkip = mockBaseQuery.mock.calls.length expect(callsWithSkip).toBe(1) - expect(callsWithoutSkip).toBeGreaterThan(2) + expect(callsWithoutSkip).toBeGreaterThanOrEqual(2) storeRef.store.dispatch(api.util.resetApiState()) }) @@ -218,8 +224,8 @@ describe('polling tests', () => { const getSubs = createSubscriptionGetter(queryCacheKey) await delay(1) - expect(Object.keys(getSubs())).toHaveLength(1) - expect(getSubs()[requestId].skipPollingIfUnfocused).toBe(false) + expect(getSubs().size).toBe(1) + expect(getSubs().get(requestId)?.skipPollingIfUnfocused).toBe(false) subscription.updateSubscriptionOptions({ pollingInterval: 20, @@ -227,7 +233,54 @@ describe('polling tests', () => { }) await delay(1) - expect(Object.keys(getSubs())).toHaveLength(1) - expect(getSubs()[requestId].skipPollingIfUnfocused).toBe(true) + expect(getSubs().size).toBe(1) + expect(getSubs().get(requestId)?.skipPollingIfUnfocused).toBe(true) + }) + + it('should minimize polling recalculations when adding multiple subscribers', async () => { + // Reset any existing state + const storeRef = setupApiStore(api, undefined, { + withoutTestLifecycles: true, + }) + + const SUBSCRIBER_COUNT = 10 + const subscriptions: QueryActionCreatorResult[] = [] + + // Add 10 subscribers to the same endpoint with polling enabled + for (let i = 0; i < SUBSCRIBER_COUNT; i++) { + const subscription = storeRef.store.dispatch( + getPosts.initiate(1, { + subscriptionOptions: { pollingInterval: 1000 }, + subscribe: true, + }), + ) + subscriptions.push(subscription) + } + + // Wait a bit for all subscriptions to be processed + await Promise.all(subscriptions) + + // Wait for the poll update timer + await delay(25) + + // Get the polling state using the secret "getPolling" action + const currentPolls = storeRef.store.dispatch({ + type: `${api.reducerPath}/getPolling`, + }) as any + + // Get the query cache key for our endpoint + const queryCacheKey = subscriptions[0].queryCacheKey + + // Check the poll update counters + const pollUpdateCounters = currentPolls.pollUpdateCounters || {} + const updateCount = pollUpdateCounters[queryCacheKey] || 0 + + // With batching optimization, this should be much lower than SUBSCRIBER_COUNT + // Ideally 1, but could be slightly higher due to timing + expect(updateCount).toBeGreaterThanOrEqual(1) + expect(updateCount).toBeLessThanOrEqual(2) + + // Clean up subscriptions + subscriptions.forEach((sub) => sub.unsubscribe()) }) }) diff --git a/packages/toolkit/src/query/utils/getOrInsert.ts b/packages/toolkit/src/query/utils/getOrInsert.ts index 124da032ea..8ae351e02f 100644 --- a/packages/toolkit/src/query/utils/getOrInsert.ts +++ b/packages/toolkit/src/query/utils/getOrInsert.ts @@ -1,3 +1,7 @@ +// Duplicate some of the utils in `/src/utils` to ensure +// we don't end up dragging in larger chunks of the RTK core +// into the RTKQ bundle + export function getOrInsert( map: WeakMap, key: K, @@ -13,3 +17,25 @@ export function getOrInsert( return map.set(key, value).get(key) as V } + +export function getOrInsertComputed( + map: WeakMap, + key: K, + compute: (key: K) => V, +): V +export function getOrInsertComputed( + map: Map, + key: K, + compute: (key: K) => V, +): V +export function getOrInsertComputed( + map: Map | WeakMap, + key: K, + compute: (key: K) => V, +): V { + if (map.has(key)) return map.get(key) as V + + return map.set(key, compute(key)).get(key) as V +} + +export const createNewMap = () => new Map()