diff --git a/packages/toolkit/src/query/core/buildInitiate.ts b/packages/toolkit/src/query/core/buildInitiate.ts index 54ba080373..acb6a0e734 100644 --- a/packages/toolkit/src/query/core/buildInitiate.ts +++ b/packages/toolkit/src/query/core/buildInitiate.ts @@ -22,7 +22,7 @@ import { type QueryDefinition, type ResultTypeFrom, } from '../endpointDefinitions' -import { countObjectKeys, getOrInsert, isNotNullish } from '../utils' +import { filterNullishValues } from '../utils' import type { InfiniteData, InfiniteQueryConfigOptions, @@ -271,7 +271,7 @@ export function buildInitiate({ mutationThunk, api, context, - internalState, + getInternalState, }: { serializeQueryArgs: InternalSerializeQueryArgs queryThunk: QueryThunk @@ -279,9 +279,12 @@ export function buildInitiate({ mutationThunk: MutationThunk api: Api context: ApiContext - internalState: InternalMiddlewareState + getInternalState: (dispatch: Dispatch) => InternalMiddlewareState }) { - const { runningQueries, runningMutations } = internalState + const getRunningQueries = (dispatch: Dispatch) => + getInternalState(dispatch)?.runningQueries + const getRunningMutations = (dispatch: Dispatch) => + getInternalState(dispatch)?.runningMutations const { unsubscribeQueryResult, @@ -306,7 +309,7 @@ export function buildInitiate({ endpointDefinition, endpointName, }) - return runningQueries.get(dispatch)?.[queryCacheKey] as + return getRunningQueries(dispatch)?.get(queryCacheKey) as | QueryActionCreatorResult | InfiniteQueryActionCreatorResult | undefined @@ -322,7 +325,7 @@ export function buildInitiate({ fixedCacheKeyOrRequestId: string, ) { return (dispatch: Dispatch) => { - return runningMutations.get(dispatch)?.[fixedCacheKeyOrRequestId] as + return getRunningMutations(dispatch)?.get(fixedCacheKeyOrRequestId) as | MutationActionCreatorResult | undefined } @@ -330,12 +333,12 @@ export function buildInitiate({ function getRunningQueriesThunk() { return (dispatch: Dispatch) => - Object.values(runningQueries.get(dispatch) || {}).filter(isNotNullish) + filterNullishValues(getRunningQueries(dispatch)) } function getRunningMutationsThunk() { return (dispatch: Dispatch) => - Object.values(runningMutations.get(dispatch) || {}).filter(isNotNullish) + filterNullishValues(getRunningMutations(dispatch)) } function middlewareWarning(dispatch: Dispatch) { @@ -429,7 +432,7 @@ You must add the middleware for RTK-Query to function correctly!`, const skippedSynchronously = stateAfter.requestId !== requestId - const runningQuery = runningQueries.get(dispatch)?.[queryCacheKey] + const runningQuery = getRunningQueries(dispatch)?.get(queryCacheKey) const selectFromState = () => selector(getState()) const statePromise: AnyActionCreatorResult = Object.assign( @@ -489,14 +492,11 @@ You must add the middleware for RTK-Query to function correctly!`, ) if (!runningQuery && !skippedSynchronously && !forceQueryFn) { - const running = getOrInsert(runningQueries, dispatch, {}) - running[queryCacheKey] = statePromise + const runningQueries = getRunningQueries(dispatch)! + runningQueries.set(queryCacheKey, statePromise) statePromise.then(() => { - delete running[queryCacheKey] - if (!countObjectKeys(running)) { - runningQueries.delete(dispatch) - } + runningQueries.delete(queryCacheKey) }) } @@ -559,23 +559,17 @@ You must add the middleware for RTK-Query to function correctly!`, reset, }) - const running = runningMutations.get(dispatch) || {} - runningMutations.set(dispatch, running) - running[requestId] = ret + const runningMutations = getRunningMutations(dispatch)! + + runningMutations.set(requestId, ret) ret.then(() => { - delete running[requestId] - if (!countObjectKeys(running)) { - runningMutations.delete(dispatch) - } + runningMutations.delete(requestId) }) if (fixedCacheKey) { - running[fixedCacheKey] = ret + runningMutations.set(fixedCacheKey, ret) ret.then(() => { - if (running[fixedCacheKey] === ret) { - delete running[fixedCacheKey] - if (!countObjectKeys(running)) { - runningMutations.delete(dispatch) - } + if (runningMutations.get(fixedCacheKey) === ret) { + runningMutations.delete(fixedCacheKey) } }) } diff --git a/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts b/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts index 94737d3ea3..02e8ac8aca 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts @@ -13,27 +13,27 @@ export type ReferenceCacheCollection = never /** * @example - * ```ts - * // codeblock-meta title="keepUnusedDataFor example" - * import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react' - * interface Post { - * id: number - * name: string - * } - * type PostsResponse = Post[] - * - * const api = createApi({ - * baseQuery: fetchBaseQuery({ baseUrl: '/' }), - * endpoints: (build) => ({ - * getPosts: build.query({ - * query: () => 'posts', - * // highlight-start - * keepUnusedDataFor: 5 - * // highlight-end - * }) - * }) - * }) - * ``` + * ```ts + * // codeblock-meta title="keepUnusedDataFor example" + * import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react' + * interface Post { + * id: number + * name: string + * } + * type PostsResponse = Post[] + * + * const api = createApi({ + * baseQuery: fetchBaseQuery({ baseUrl: '/' }), + * endpoints: (build) => ({ + * getPosts: build.query({ + * query: () => 'posts', + * // highlight-start + * keepUnusedDataFor: 5 + * // highlight-end + * }) + * }) + * }) + * ``` */ export type CacheCollectionQueryExtraOptions = { /** @@ -64,8 +64,6 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ const { removeQueryResult, unsubscribeQueryResult, cacheEntriesUpserted } = api.internalActions - const runningQueries = internalState.runningQueries.get(mwApi.dispatch)! - const canTriggerUnsubscribe = isAnyOf( unsubscribeQueryResult.match, queryThunk.fulfilled, @@ -80,8 +78,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ } const hasSubscriptions = subscriptions.size > 0 - const isRunning = runningQueries?.[queryCacheKey] !== undefined - return hasSubscriptions || isRunning + return hasSubscriptions } const currentRemovalTimeouts: QueryStateMeta = {} diff --git a/packages/toolkit/src/query/core/buildMiddleware/index.ts b/packages/toolkit/src/query/core/buildMiddleware/index.ts index 4a24e1ec88..0a7bc92222 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/index.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/index.ts @@ -45,7 +45,7 @@ export function buildMiddleware< ReducerPath extends string, TagTypes extends string, >(input: BuildMiddlewareInput) { - const { reducerPath, queryThunk, api, context, internalState } = input + const { reducerPath, queryThunk, api, context, getInternalState } = input const { apiUid } = context const actions = { @@ -73,6 +73,8 @@ export function buildMiddleware< > = (mwApi) => { let initialized = false + const internalState = getInternalState(mwApi.dispatch) + const builderArgs = { ...(input as any as BuildMiddlewareInput< EndpointDefinitions, diff --git a/packages/toolkit/src/query/core/buildMiddleware/types.ts b/packages/toolkit/src/query/core/buildMiddleware/types.ts index e3ffe3340e..d0b3aa288d 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/types.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/types.ts @@ -47,18 +47,12 @@ export interface InternalMiddlewareState { currentSubscriptions: SubscriptionInternalState currentPolls: Map runningQueries: Map< - Dispatch, - Record< - string, - | QueryActionCreatorResult - | InfiniteQueryActionCreatorResult - | undefined - > - > - runningMutations: Map< - Dispatch, - Record | undefined> + string, + | QueryActionCreatorResult + | InfiniteQueryActionCreatorResult + | undefined > + runningMutations: Map | undefined> } export interface SubscriptionSelectors { @@ -84,7 +78,7 @@ export interface BuildMiddlewareInput< endpointName: string, queryArgs: any, ) => (dispatch: Dispatch) => QueryActionCreatorResult | undefined - internalState: InternalMiddlewareState + getInternalState: (dispatch: Dispatch) => InternalMiddlewareState } export type SubMiddlewareApi = MiddlewareAPI< diff --git a/packages/toolkit/src/query/core/module.ts b/packages/toolkit/src/query/core/module.ts index e9bfa08b76..e2857fffcc 100644 --- a/packages/toolkit/src/query/core/module.ts +++ b/packages/toolkit/src/query/core/module.ts @@ -3,6 +3,7 @@ */ import type { ActionCreatorWithPayload, + Dispatch, Middleware, Reducer, ThunkAction, @@ -72,6 +73,7 @@ import { buildThunks } from './buildThunks' import { createSelector as _createSelector } from './rtkImports' import { onFocus, onFocusLost, onOffline, onOnline } from './setupListeners' import type { InternalMiddlewareState } from './buildMiddleware/types' +import { getOrInsertComputed } from '../utils' /** * `ifOlderThan` - (default: `false` | `number`) - _number is value in seconds_ @@ -619,11 +621,17 @@ export const coreModule = ({ }) safeAssign(api.internalActions, sliceActions) - const internalState: InternalMiddlewareState = { - currentSubscriptions: new Map(), - currentPolls: new Map(), - runningQueries: new Map(), - runningMutations: new Map(), + const internalStateMap = new WeakMap() + + const getInternalState = (dispatch: Dispatch) => { + const state = getOrInsertComputed(internalStateMap, dispatch, () => ({ + currentSubscriptions: new Map(), + currentPolls: new Map(), + runningQueries: new Map(), + runningMutations: new Map(), + })) + + return state } const { @@ -641,7 +649,7 @@ export const coreModule = ({ api, serializeQueryArgs: serializeQueryArgs as any, context, - internalState, + getInternalState, }) safeAssign(api.util, { @@ -661,7 +669,7 @@ export const coreModule = ({ assertTagType, selectors, getRunningQueryThunk, - internalState, + getInternalState, }) safeAssign(api.util, middlewareActions) diff --git a/packages/toolkit/src/query/tests/buildMiddleware.test.tsx b/packages/toolkit/src/query/tests/buildMiddleware.test.tsx index 2abac64498..3fb27f219e 100644 --- a/packages/toolkit/src/query/tests/buildMiddleware.test.tsx +++ b/packages/toolkit/src/query/tests/buildMiddleware.test.tsx @@ -1,6 +1,7 @@ import { createApi } from '@reduxjs/toolkit/query' import { delay } from 'msw' import { actionsReducer, setupApiStore } from '../../tests/utils/helpers' +import { vi } from 'vitest' const baseQuery = (args?: any) => ({ data: args }) const api = createApi({ @@ -213,3 +214,53 @@ it('correctly stringifies subscription state and dispatches subscriptionsUpdated subscriptionState['getBananas(undefined)']?.[subscription3.requestId], ).toEqual({}) }) + +it('does not leak subscription state between multiple stores using the same API instance (SSR scenario)', async () => { + vi.useFakeTimers() + // Simulate SSR: create API once at module level + const sharedApi = createApi({ + baseQuery: (args?: any) => ({ data: args }), + tagTypes: ['Test'], + endpoints: (build) => ({ + getTest: build.query({ + query(id) { + return { url: `test/${id}` } + }, + }), + }), + }) + + // Create first store (simulating first SSR request) + const store1Ref = setupApiStore(sharedApi, {}, { withoutListeners: true }) + + // Add subscription in store1 + const sub1 = store1Ref.store.dispatch( + sharedApi.endpoints.getTest.initiate(1, { + subscriptionOptions: { pollingInterval: 1000 }, + }), + ) + vi.advanceTimersByTime(10) + await sub1 + + // Wait for subscription sync (500ms + buffer) + vi.advanceTimersByTime(600) + + // Verify store1 has the subscription + const store1SubscriptionSelectors = store1Ref.store.dispatch( + sharedApi.internalActions.internal_getRTKQSubscriptions(), + ) as any + const store1InternalSubs = store1SubscriptionSelectors.getSubscriptions() + expect(store1InternalSubs.size).toBe(1) + + // Create second store (simulating second SSR request) + const store2Ref = setupApiStore(sharedApi, {}, { withoutListeners: true }) + + // Check subscriptions via internal action + const store2SubscriptionSelectors = store2Ref.store.dispatch( + sharedApi.internalActions.internal_getRTKQSubscriptions(), + ) as any + + const store2InternalSubs = store2SubscriptionSelectors.getSubscriptions() + + expect(store2InternalSubs.size).toBe(0) +}) diff --git a/packages/toolkit/src/query/utils/isNotNullish.ts b/packages/toolkit/src/query/utils/isNotNullish.ts index e2d8f4b172..1471f9e8cf 100644 --- a/packages/toolkit/src/query/utils/isNotNullish.ts +++ b/packages/toolkit/src/query/utils/isNotNullish.ts @@ -1,3 +1,7 @@ export function isNotNullish(v: T | null | undefined): v is T { return v != null } + +export function filterNullishValues(map?: Map) { + return [...(map?.values() ?? [])].filter(isNotNullish) as NonNullable[] +}