Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 22 additions & 28 deletions packages/toolkit/src/query/core/buildInitiate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {
type QueryDefinition,
type ResultTypeFrom,
} from '../endpointDefinitions'
import { countObjectKeys, getOrInsert, isNotNullish } from '../utils'
import { filterNullishValues } from '../utils'
import type {
InfiniteData,
InfiniteQueryConfigOptions,
Expand Down Expand Up @@ -271,17 +271,20 @@ export function buildInitiate({
mutationThunk,
api,
context,
internalState,
getInternalState,
}: {
serializeQueryArgs: InternalSerializeQueryArgs
queryThunk: QueryThunk
infiniteQueryThunk: InfiniteQueryThunk<any>
mutationThunk: MutationThunk
api: Api<any, EndpointDefinitions, any, any>
context: ApiContext<EndpointDefinitions>
internalState: InternalMiddlewareState
getInternalState: (dispatch: Dispatch) => InternalMiddlewareState
}) {
const { runningQueries, runningMutations } = internalState
const getRunningQueries = (dispatch: Dispatch) =>
getInternalState(dispatch)?.runningQueries
const getRunningMutations = (dispatch: Dispatch) =>
getInternalState(dispatch)?.runningMutations

const {
unsubscribeQueryResult,
Expand All @@ -306,7 +309,7 @@ export function buildInitiate({
endpointDefinition,
endpointName,
})
return runningQueries.get(dispatch)?.[queryCacheKey] as
return getRunningQueries(dispatch)?.get(queryCacheKey) as
| QueryActionCreatorResult<never>
| InfiniteQueryActionCreatorResult<never>
| undefined
Expand All @@ -322,20 +325,20 @@ export function buildInitiate({
fixedCacheKeyOrRequestId: string,
) {
return (dispatch: Dispatch) => {
return runningMutations.get(dispatch)?.[fixedCacheKeyOrRequestId] as
return getRunningMutations(dispatch)?.get(fixedCacheKeyOrRequestId) as
| MutationActionCreatorResult<never>
| undefined
}
}

function getRunningQueriesThunk() {
return (dispatch: Dispatch) =>
Object.values(runningQueries.get(dispatch) || {}).filter(isNotNullish)
filterNullishValues(getRunningQueries(dispatch))
}

function getRunningMutationsThunk() {
return (dispatch: Dispatch) =>
Object.values(runningMutations.get(dispatch) || {}).filter(isNotNullish)
filterNullishValues(getRunningMutations(dispatch))
}

function middlewareWarning(dispatch: Dispatch) {
Expand Down Expand Up @@ -429,7 +432,7 @@ You must add the middleware for RTK-Query to function correctly!`,

const skippedSynchronously = stateAfter.requestId !== requestId

const runningQuery = runningQueries.get(dispatch)?.[queryCacheKey]
const runningQuery = getRunningQueries(dispatch)?.get(queryCacheKey)
const selectFromState = () => selector(getState())

const statePromise: AnyActionCreatorResult = Object.assign(
Expand Down Expand Up @@ -489,14 +492,11 @@ You must add the middleware for RTK-Query to function correctly!`,
)

if (!runningQuery && !skippedSynchronously && !forceQueryFn) {
const running = getOrInsert(runningQueries, dispatch, {})
running[queryCacheKey] = statePromise
const runningQueries = getRunningQueries(dispatch)!
runningQueries.set(queryCacheKey, statePromise)

statePromise.then(() => {
delete running[queryCacheKey]
if (!countObjectKeys(running)) {
runningQueries.delete(dispatch)
}
runningQueries.delete(queryCacheKey)
})
}

Expand Down Expand Up @@ -559,23 +559,17 @@ You must add the middleware for RTK-Query to function correctly!`,
reset,
})

const running = runningMutations.get(dispatch) || {}
runningMutations.set(dispatch, running)
running[requestId] = ret
const runningMutations = getRunningMutations(dispatch)!

runningMutations.set(requestId, ret)
ret.then(() => {
delete running[requestId]
if (!countObjectKeys(running)) {
runningMutations.delete(dispatch)
}
runningMutations.delete(requestId)
})
if (fixedCacheKey) {
running[fixedCacheKey] = ret
runningMutations.set(fixedCacheKey, ret)
ret.then(() => {
if (running[fixedCacheKey] === ret) {
delete running[fixedCacheKey]
if (!countObjectKeys(running)) {
runningMutations.delete(dispatch)
}
if (runningMutations.get(fixedCacheKey) === ret) {
runningMutations.delete(fixedCacheKey)
}
})
}
Expand Down
47 changes: 22 additions & 25 deletions packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,27 @@ export type ReferenceCacheCollection = never

/**
* @example
* ```ts
* // codeblock-meta title="keepUnusedDataFor example"
* import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react'
* interface Post {
* id: number
* name: string
* }
* type PostsResponse = Post[]
*
* const api = createApi({
* baseQuery: fetchBaseQuery({ baseUrl: '/' }),
* endpoints: (build) => ({
* getPosts: build.query<PostsResponse, void>({
* query: () => 'posts',
* // highlight-start
* keepUnusedDataFor: 5
* // highlight-end
* })
* })
* })
* ```
* ```ts
* // codeblock-meta title="keepUnusedDataFor example"
* import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react'
* interface Post {
* id: number
* name: string
* }
* type PostsResponse = Post[]
*
* const api = createApi({
* baseQuery: fetchBaseQuery({ baseUrl: '/' }),
* endpoints: (build) => ({
* getPosts: build.query<PostsResponse, void>({
* query: () => 'posts',
* // highlight-start
* keepUnusedDataFor: 5
* // highlight-end
* })
* })
* })
* ```
*/
export type CacheCollectionQueryExtraOptions = {
/**
Expand Down Expand Up @@ -64,8 +64,6 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
const { removeQueryResult, unsubscribeQueryResult, cacheEntriesUpserted } =
api.internalActions

const runningQueries = internalState.runningQueries.get(mwApi.dispatch)!

const canTriggerUnsubscribe = isAnyOf(
unsubscribeQueryResult.match,
queryThunk.fulfilled,
Expand All @@ -80,8 +78,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
}

const hasSubscriptions = subscriptions.size > 0
const isRunning = runningQueries?.[queryCacheKey] !== undefined
return hasSubscriptions || isRunning
return hasSubscriptions
}

const currentRemovalTimeouts: QueryStateMeta<TimeoutId> = {}
Expand Down
4 changes: 3 additions & 1 deletion packages/toolkit/src/query/core/buildMiddleware/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export function buildMiddleware<
ReducerPath extends string,
TagTypes extends string,
>(input: BuildMiddlewareInput<Definitions, ReducerPath, TagTypes>) {
const { reducerPath, queryThunk, api, context, internalState } = input
const { reducerPath, queryThunk, api, context, getInternalState } = input
const { apiUid } = context

const actions = {
Expand Down Expand Up @@ -73,6 +73,8 @@ export function buildMiddleware<
> = (mwApi) => {
let initialized = false

const internalState = getInternalState(mwApi.dispatch)

const builderArgs = {
...(input as any as BuildMiddlewareInput<
EndpointDefinitions,
Expand Down
18 changes: 6 additions & 12 deletions packages/toolkit/src/query/core/buildMiddleware/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,12 @@ export interface InternalMiddlewareState {
currentSubscriptions: SubscriptionInternalState
currentPolls: Map<string, QueryPollState>
runningQueries: Map<
Dispatch,
Record<
string,
| QueryActionCreatorResult<any>
| InfiniteQueryActionCreatorResult<any>
| undefined
>
>
runningMutations: Map<
Dispatch,
Record<string, MutationActionCreatorResult<any> | undefined>
string,
| QueryActionCreatorResult<any>
| InfiniteQueryActionCreatorResult<any>
| undefined
>
runningMutations: Map<string, MutationActionCreatorResult<any> | undefined>
}

export interface SubscriptionSelectors {
Expand All @@ -84,7 +78,7 @@ export interface BuildMiddlewareInput<
endpointName: string,
queryArgs: any,
) => (dispatch: Dispatch) => QueryActionCreatorResult<any> | undefined
internalState: InternalMiddlewareState
getInternalState: (dispatch: Dispatch) => InternalMiddlewareState
}

export type SubMiddlewareApi = MiddlewareAPI<
Expand Down
22 changes: 15 additions & 7 deletions packages/toolkit/src/query/core/module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/
import type {
ActionCreatorWithPayload,
Dispatch,
Middleware,
Reducer,
ThunkAction,
Expand Down Expand Up @@ -72,6 +73,7 @@ import { buildThunks } from './buildThunks'
import { createSelector as _createSelector } from './rtkImports'
import { onFocus, onFocusLost, onOffline, onOnline } from './setupListeners'
import type { InternalMiddlewareState } from './buildMiddleware/types'
import { getOrInsertComputed } from '../utils'

/**
* `ifOlderThan` - (default: `false` | `number`) - _number is value in seconds_
Expand Down Expand Up @@ -619,11 +621,17 @@ export const coreModule = ({
})
safeAssign(api.internalActions, sliceActions)

const internalState: InternalMiddlewareState = {
currentSubscriptions: new Map(),
currentPolls: new Map(),
runningQueries: new Map(),
runningMutations: new Map(),
const internalStateMap = new WeakMap<Dispatch, InternalMiddlewareState>()

const getInternalState = (dispatch: Dispatch) => {
const state = getOrInsertComputed(internalStateMap, dispatch, () => ({
currentSubscriptions: new Map(),
currentPolls: new Map(),
runningQueries: new Map(),
runningMutations: new Map(),
}))

return state
}

const {
Expand All @@ -641,7 +649,7 @@ export const coreModule = ({
api,
serializeQueryArgs: serializeQueryArgs as any,
context,
internalState,
getInternalState,
})

safeAssign(api.util, {
Expand All @@ -661,7 +669,7 @@ export const coreModule = ({
assertTagType,
selectors,
getRunningQueryThunk,
internalState,
getInternalState,
})
safeAssign(api.util, middlewareActions)

Expand Down
51 changes: 51 additions & 0 deletions packages/toolkit/src/query/tests/buildMiddleware.test.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { createApi } from '@reduxjs/toolkit/query'
import { delay } from 'msw'
import { actionsReducer, setupApiStore } from '../../tests/utils/helpers'
import { vi } from 'vitest'

const baseQuery = (args?: any) => ({ data: args })
const api = createApi({
Expand Down Expand Up @@ -213,3 +214,53 @@ it('correctly stringifies subscription state and dispatches subscriptionsUpdated
subscriptionState['getBananas(undefined)']?.[subscription3.requestId],
).toEqual({})
})

it('does not leak subscription state between multiple stores using the same API instance (SSR scenario)', async () => {
vi.useFakeTimers()
// Simulate SSR: create API once at module level
const sharedApi = createApi({
baseQuery: (args?: any) => ({ data: args }),
tagTypes: ['Test'],
endpoints: (build) => ({
getTest: build.query<unknown, number>({
query(id) {
return { url: `test/${id}` }
},
}),
}),
})

// Create first store (simulating first SSR request)
const store1Ref = setupApiStore(sharedApi, {}, { withoutListeners: true })

// Add subscription in store1
const sub1 = store1Ref.store.dispatch(
sharedApi.endpoints.getTest.initiate(1, {
subscriptionOptions: { pollingInterval: 1000 },
}),
)
vi.advanceTimersByTime(10)
await sub1

// Wait for subscription sync (500ms + buffer)
vi.advanceTimersByTime(600)

// Verify store1 has the subscription
const store1SubscriptionSelectors = store1Ref.store.dispatch(
sharedApi.internalActions.internal_getRTKQSubscriptions(),
) as any
const store1InternalSubs = store1SubscriptionSelectors.getSubscriptions()
expect(store1InternalSubs.size).toBe(1)

// Create second store (simulating second SSR request)
const store2Ref = setupApiStore(sharedApi, {}, { withoutListeners: true })

// Check subscriptions via internal action
const store2SubscriptionSelectors = store2Ref.store.dispatch(
sharedApi.internalActions.internal_getRTKQSubscriptions(),
) as any

const store2InternalSubs = store2SubscriptionSelectors.getSubscriptions()

expect(store2InternalSubs.size).toBe(0)
})
4 changes: 4 additions & 0 deletions packages/toolkit/src/query/utils/isNotNullish.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
export function isNotNullish<T>(v: T | null | undefined): v is T {
return v != null
}

export function filterNullishValues<T>(map?: Map<any, T>) {
return [...(map?.values() ?? [])].filter(isNotNullish) as NonNullable<T>[]
}
Loading