diff --git a/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts b/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts index 2f1024d13b..050b70dba1 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts @@ -1,5 +1,5 @@ import type { QueryDefinition } from '../../endpointDefinitions' -import type { ConfigState, QueryCacheKey } from '../apiState' +import type { ConfigState, QueryCacheKey, QuerySubState } from '../apiState' import { isAnyOf } from '../rtkImports' import type { ApiMiddlewareInternalHandler, @@ -11,16 +11,6 @@ import type { export type ReferenceCacheCollection = never -function isObjectEmpty(obj: Record) { - // Apparently a for..in loop is faster than `Object.keys()` here: - // https://stackoverflow.com/a/59787784/62937 - for (const k in obj) { - // If there is at least one key, it's not empty - return false - } - return true -} - export type CacheCollectionQueryExtraOptions = { /** * Overrides the api-wide definition of `keepUnusedDataFor` for this endpoint only. _(This value is in seconds.)_ @@ -44,6 +34,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ context, internalState, selectors: { selectQueryEntry, selectConfig }, + getRunningQueryThunk, }) => { const { removeQueryResult, unsubscribeQueryResult, cacheEntriesUpserted } = api.internalActions @@ -57,7 +48,18 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ function anySubscriptionsRemainingForKey(queryCacheKey: string) { const subscriptions = internalState.currentSubscriptions[queryCacheKey] - return !!subscriptions && !isObjectEmpty(subscriptions) + if (!subscriptions) { + return false + } + + // Check if there are any keys that are NOT _running subscriptions + for (const key in subscriptions) { + if (!key.endsWith('_running')) { + return true + } + } + // Only _running subscriptions remain (or empty) + return false } const currentRemovalTimeouts: QueryStateMeta = {} @@ -69,6 +71,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ ) => { const state = mwApi.getState() const config = selectConfig(state) + if (canTriggerUnsubscribe(action)) { let queryCacheKeys: QueryCacheKey[] @@ -114,18 +117,20 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ const state = api.getState() for (const queryCacheKey of cacheKeys) { const entry = selectQueryEntry(state, queryCacheKey) - handleUnsubscribe(queryCacheKey, entry?.endpointName, api, config) + if (entry?.endpointName) { + handleUnsubscribe(queryCacheKey, entry.endpointName, api, config) + } } } function handleUnsubscribe( queryCacheKey: QueryCacheKey, - endpointName: string | undefined, + endpointName: string, api: SubMiddlewareApi, config: ConfigState, ) { const endpointDefinition = context.endpointDefinitions[ - endpointName! + endpointName ] as QueryDefinition const keepUnusedDataFor = endpointDefinition?.keepUnusedDataFor ?? config.keepUnusedDataFor @@ -151,6 +156,15 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ currentRemovalTimeouts[queryCacheKey] = setTimeout(() => { if (!anySubscriptionsRemainingForKey(queryCacheKey)) { + // Try to abort any running query for this cache key + const entry = selectQueryEntry(api.getState(), queryCacheKey) + + if (entry?.endpointName) { + const runningQuery = api.dispatch( + getRunningQueryThunk(entry.endpointName, entry.originalArgs), + ) + runningQuery?.abort() + } api.dispatch(removeQueryResult({ queryCacheKey })) } delete currentRemovalTimeouts![queryCacheKey] diff --git a/packages/toolkit/src/query/core/buildMiddleware/types.ts b/packages/toolkit/src/query/core/buildMiddleware/types.ts index a95acd2bf9..2db9cb775d 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/types.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/types.ts @@ -1,6 +1,7 @@ import type { Action, AsyncThunkAction, + Dispatch, Middleware, MiddlewareAPI, ThunkAction, @@ -54,6 +55,10 @@ export interface BuildMiddlewareInput< api: Api assertTagType: AssertTagTypes selectors: AllSelectors + getRunningQueryThunk: ( + endpointName: string, + queryArgs: any, + ) => (dispatch: Dispatch) => QueryActionCreatorResult | undefined } export type SubMiddlewareApi = MiddlewareAPI< diff --git a/packages/toolkit/src/query/core/module.ts b/packages/toolkit/src/query/core/module.ts index 41dafa1e7b..4133203e13 100644 --- a/packages/toolkit/src/query/core/module.ts +++ b/packages/toolkit/src/query/core/module.ts @@ -618,20 +618,6 @@ export const coreModule = ({ }) safeAssign(api.internalActions, sliceActions) - const { middleware, actions: middlewareActions } = buildMiddleware({ - reducerPath, - context, - queryThunk, - mutationThunk, - infiniteQueryThunk, - api, - assertTagType, - selectors, - }) - safeAssign(api.util, middlewareActions) - - safeAssign(api, { reducer: reducer as any, middleware }) - const { buildInitiateQuery, buildInitiateInfiniteQuery, @@ -656,6 +642,21 @@ export const coreModule = ({ getRunningQueriesThunk, }) + const { middleware, actions: middlewareActions } = buildMiddleware({ + reducerPath, + context, + queryThunk, + mutationThunk, + infiniteQueryThunk, + api, + assertTagType, + selectors, + getRunningQueryThunk, + }) + safeAssign(api.util, middlewareActions) + + safeAssign(api, { reducer: reducer as any, middleware }) + return { name: coreModuleName, injectEndpoint(endpointName, definition) { diff --git a/packages/toolkit/src/query/tests/buildHooks.test.tsx b/packages/toolkit/src/query/tests/buildHooks.test.tsx index 36e3b9d944..0a72ce8a58 100644 --- a/packages/toolkit/src/query/tests/buildHooks.test.tsx +++ b/packages/toolkit/src/query/tests/buildHooks.test.tsx @@ -1190,6 +1190,87 @@ describe('hooks tests', () => { ).toBe(-1) }) + test('query thunk should be aborted when component unmounts and cache entry is removed', async () => { + let abortSignalFromQueryFn: AbortSignal | undefined + + const pokemonApi = createApi({ + baseQuery: fetchBaseQuery({ baseUrl: 'https://pokeapi.co/api/v2/' }), + endpoints: (builder) => ({ + getTest: builder.query({ + async queryFn(arg, { signal }) { + abortSignalFromQueryFn = signal + + // Simulate a long-running request that should be aborted + await new Promise((resolve, reject) => { + const timeout = setTimeout(resolve, 5000) + + signal.addEventListener('abort', () => { + clearTimeout(timeout) + reject(new Error('Aborted')) + }) + }) + + return { data: 'data!' } + }, + keepUnusedDataFor: 0.01, // Very short timeout (10ms) + }), + }), + }) + + const storeRef = setupApiStore(pokemonApi, undefined, { + withoutTestLifecycles: true, + }) + + function TestComponent() { + const { data, isFetching } = pokemonApi.endpoints.getTest.useQuery(1) + + return ( +
+
{String(isFetching)}
+
{data || 'no data'}
+
+ ) + } + + function App() { + const [showComponent, setShowComponent] = useState(true) + + return ( +
+ {showComponent && } + +
+ ) + } + + render(, { wrapper: storeRef.wrapper }) + + // Wait for the query to start + await waitFor(() => + expect(screen.getByTestId('isFetching').textContent).toBe('true'), + ) + + // Verify we have an abort signal + expect(abortSignalFromQueryFn).toBeDefined() + expect(abortSignalFromQueryFn!.aborted).toBe(false) + + // Unmount the component + fireEvent.click(screen.getByTestId('unmount')) + + // Wait for the cache entry to be removed (keepUnusedDataFor: 0.01s = 10ms) + await act(async () => { + await delay(100) + }) + + // The abort signal should now be aborted + expect(abortSignalFromQueryFn!.aborted).toBe(true) + }) + describe('Hook middleware requirements', () => { const consoleErrorSpy = vi .spyOn(console, 'error') @@ -1898,7 +1979,6 @@ describe('hooks tests', () => { const checkNumQueries = (count: number) => { const cacheEntries = Object.keys(storeRef.store.getState().api.queries) const queries = cacheEntries.length - //console.log('queries', queries, storeRef.store.getState().api.queries) expect(queries).toBe(count) } diff --git a/packages/toolkit/src/query/tests/retry.test.ts b/packages/toolkit/src/query/tests/retry.test.ts index 03365fd0af..b2233021b0 100644 --- a/packages/toolkit/src/query/tests/retry.test.ts +++ b/packages/toolkit/src/query/tests/retry.test.ts @@ -465,4 +465,87 @@ describe('configuration', () => { expect(baseBaseQuery).toHaveBeenCalledOnce() }) + + test('retryCondition receives abort signal and stops retrying when cache entry is removed', async () => { + let capturedSignal: AbortSignal | undefined + let retryAttempts = 0 + + const baseBaseQuery = vi.fn< + Parameters, + ReturnType + >() + + // Always return an error to trigger retries + baseBaseQuery.mockResolvedValue({ error: 'network error' }) + + let retryConditionCalled = false + + const baseQuery = retry(baseBaseQuery, { + retryCondition: (error, args, { attempt, baseQueryApi }) => { + retryConditionCalled = true + retryAttempts = attempt + capturedSignal = baseQueryApi.signal + + // Stop retrying if the signal is aborted + if (baseQueryApi.signal.aborted) { + return false + } + + // Otherwise, retry up to 10 times + return attempt <= 10 + }, + backoff: async () => { + // Short backoff for faster test + await new Promise((resolve) => setTimeout(resolve, 10)) + }, + }) + + const api = createApi({ + baseQuery, + endpoints: (build) => ({ + getTest: build.query({ + query: (id) => ({ url: `test/${id}` }), + keepUnusedDataFor: 0.01, // Very short timeout (10ms) + }), + }), + }) + + const storeRef = setupApiStore(api, undefined, { + withoutTestLifecycles: true, + }) + + // Start the query + const queryPromise = storeRef.store.dispatch( + api.endpoints.getTest.initiate(1), + ) + + // Wait for the first retry to happen so we capture the signal + await loopTimers(2) + + // Verify the retry condition was called and we have a signal + expect(retryConditionCalled).toBe(true) + expect(capturedSignal).toBeDefined() + expect(capturedSignal!.aborted).toBe(false) + + // Unsubscribe to trigger cache removal + queryPromise.unsubscribe() + + // Wait for the cache entry to be removed (keepUnusedDataFor: 0.01s = 10ms) + await vi.advanceTimersByTimeAsync(50) + + // Allow some time for more retries to potentially happen + await loopTimers(3) + + // The signal should now be aborted + expect(capturedSignal!.aborted).toBe(true) + + // We should have stopped retrying early due to the abort signal + // If abort signal wasn't working, we'd see many more retry attempts + expect(retryAttempts).toBeLessThan(10) + + // The base query should have been called at least once (initial attempt) + // but not the full 10+ times it would without abort signal + expect(baseBaseQuery).toHaveBeenCalled() + expect(baseBaseQuery.mock.calls.length).toBeLessThan(10) + }) })