Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { QueryDefinition } from '../../endpointDefinitions'
import type { ConfigState, QueryCacheKey } from '../apiState'
import type { ConfigState, QueryCacheKey, QuerySubState } from '../apiState'
import { isAnyOf } from '../rtkImports'
import type {
ApiMiddlewareInternalHandler,
Expand All @@ -11,16 +11,6 @@ import type {

export type ReferenceCacheCollection = never

function isObjectEmpty(obj: Record<any, any>) {
// Apparently a for..in loop is faster than `Object.keys()` here:
// https://stackoverflow.com/a/59787784/62937
for (const k in obj) {
// If there is at least one key, it's not empty
return false
}
return true
}

export type CacheCollectionQueryExtraOptions = {
/**
* Overrides the api-wide definition of `keepUnusedDataFor` for this endpoint only. _(This value is in seconds.)_
Expand All @@ -44,6 +34,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
context,
internalState,
selectors: { selectQueryEntry, selectConfig },
getRunningQueryThunk,
}) => {
const { removeQueryResult, unsubscribeQueryResult, cacheEntriesUpserted } =
api.internalActions
Expand All @@ -57,7 +48,18 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({

function anySubscriptionsRemainingForKey(queryCacheKey: string) {
const subscriptions = internalState.currentSubscriptions[queryCacheKey]
return !!subscriptions && !isObjectEmpty(subscriptions)
if (!subscriptions) {
return false
}

// Check if there are any keys that are NOT _running subscriptions
for (const key in subscriptions) {
if (!key.endsWith('_running')) {
return true
}
}
// Only _running subscriptions remain (or empty)
return false
}

const currentRemovalTimeouts: QueryStateMeta<TimeoutId> = {}
Expand All @@ -69,6 +71,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
) => {
const state = mwApi.getState()
const config = selectConfig(state)

if (canTriggerUnsubscribe(action)) {
let queryCacheKeys: QueryCacheKey[]

Expand Down Expand Up @@ -114,18 +117,20 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
const state = api.getState()
for (const queryCacheKey of cacheKeys) {
const entry = selectQueryEntry(state, queryCacheKey)
handleUnsubscribe(queryCacheKey, entry?.endpointName, api, config)
if (entry?.endpointName) {
handleUnsubscribe(queryCacheKey, entry.endpointName, api, config)
}
}
}

function handleUnsubscribe(
queryCacheKey: QueryCacheKey,
endpointName: string | undefined,
endpointName: string,
api: SubMiddlewareApi,
config: ConfigState<string>,
) {
const endpointDefinition = context.endpointDefinitions[
endpointName!
endpointName
] as QueryDefinition<any, any, any, any>
const keepUnusedDataFor =
endpointDefinition?.keepUnusedDataFor ?? config.keepUnusedDataFor
Expand All @@ -151,6 +156,15 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({

currentRemovalTimeouts[queryCacheKey] = setTimeout(() => {
if (!anySubscriptionsRemainingForKey(queryCacheKey)) {
// Try to abort any running query for this cache key
const entry = selectQueryEntry(api.getState(), queryCacheKey)

if (entry?.endpointName) {
const runningQuery = api.dispatch(
getRunningQueryThunk(entry.endpointName, entry.originalArgs),
)
runningQuery?.abort()
}
api.dispatch(removeQueryResult({ queryCacheKey }))
}
delete currentRemovalTimeouts![queryCacheKey]
Expand Down
5 changes: 5 additions & 0 deletions packages/toolkit/src/query/core/buildMiddleware/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type {
Action,
AsyncThunkAction,
Dispatch,
Middleware,
MiddlewareAPI,
ThunkAction,
Expand Down Expand Up @@ -54,6 +55,10 @@ export interface BuildMiddlewareInput<
api: Api<any, Definitions, ReducerPath, TagTypes>
assertTagType: AssertTagTypes
selectors: AllSelectors
getRunningQueryThunk: (
endpointName: string,
queryArgs: any,
) => (dispatch: Dispatch) => QueryActionCreatorResult<any> | undefined
}

export type SubMiddlewareApi = MiddlewareAPI<
Expand Down
29 changes: 15 additions & 14 deletions packages/toolkit/src/query/core/module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -618,20 +618,6 @@ export const coreModule = ({
})
safeAssign(api.internalActions, sliceActions)

const { middleware, actions: middlewareActions } = buildMiddleware({
reducerPath,
context,
queryThunk,
mutationThunk,
infiniteQueryThunk,
api,
assertTagType,
selectors,
})
safeAssign(api.util, middlewareActions)

safeAssign(api, { reducer: reducer as any, middleware })

const {
buildInitiateQuery,
buildInitiateInfiniteQuery,
Expand All @@ -656,6 +642,21 @@ export const coreModule = ({
getRunningQueriesThunk,
})

const { middleware, actions: middlewareActions } = buildMiddleware({
reducerPath,
context,
queryThunk,
mutationThunk,
infiniteQueryThunk,
api,
assertTagType,
selectors,
getRunningQueryThunk,
})
safeAssign(api.util, middlewareActions)

safeAssign(api, { reducer: reducer as any, middleware })

return {
name: coreModuleName,
injectEndpoint(endpointName, definition) {
Expand Down
82 changes: 81 additions & 1 deletion packages/toolkit/src/query/tests/buildHooks.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,87 @@ describe('hooks tests', () => {
).toBe(-1)
})

test('query thunk should be aborted when component unmounts and cache entry is removed', async () => {
let abortSignalFromQueryFn: AbortSignal | undefined

const pokemonApi = createApi({
baseQuery: fetchBaseQuery({ baseUrl: 'https://pokeapi.co/api/v2/' }),
endpoints: (builder) => ({
getTest: builder.query<string, number>({
async queryFn(arg, { signal }) {
abortSignalFromQueryFn = signal

// Simulate a long-running request that should be aborted
await new Promise((resolve, reject) => {
const timeout = setTimeout(resolve, 5000)

signal.addEventListener('abort', () => {
clearTimeout(timeout)
reject(new Error('Aborted'))
})
})

return { data: 'data!' }
},
keepUnusedDataFor: 0.01, // Very short timeout (10ms)
}),
}),
})

const storeRef = setupApiStore(pokemonApi, undefined, {
withoutTestLifecycles: true,
})

function TestComponent() {
const { data, isFetching } = pokemonApi.endpoints.getTest.useQuery(1)

return (
<div>
<div data-testid="isFetching">{String(isFetching)}</div>
<div data-testid="data">{data || 'no data'}</div>
</div>
)
}

function App() {
const [showComponent, setShowComponent] = useState(true)

return (
<div>
{showComponent && <TestComponent />}
<button
data-testid="unmount"
onClick={() => setShowComponent(false)}
>
Unmount Component
</button>
</div>
)
}

render(<App />, { wrapper: storeRef.wrapper })

// Wait for the query to start
await waitFor(() =>
expect(screen.getByTestId('isFetching').textContent).toBe('true'),
)

// Verify we have an abort signal
expect(abortSignalFromQueryFn).toBeDefined()
expect(abortSignalFromQueryFn!.aborted).toBe(false)

// Unmount the component
fireEvent.click(screen.getByTestId('unmount'))

// Wait for the cache entry to be removed (keepUnusedDataFor: 0.01s = 10ms)
await act(async () => {
await delay(100)
})

// The abort signal should now be aborted
expect(abortSignalFromQueryFn!.aborted).toBe(true)
})

describe('Hook middleware requirements', () => {
const consoleErrorSpy = vi
.spyOn(console, 'error')
Expand Down Expand Up @@ -1898,7 +1979,6 @@ describe('hooks tests', () => {
const checkNumQueries = (count: number) => {
const cacheEntries = Object.keys(storeRef.store.getState().api.queries)
const queries = cacheEntries.length
//console.log('queries', queries, storeRef.store.getState().api.queries)

expect(queries).toBe(count)
}
Expand Down
83 changes: 83 additions & 0 deletions packages/toolkit/src/query/tests/retry.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -465,4 +465,87 @@ describe('configuration', () => {

expect(baseBaseQuery).toHaveBeenCalledOnce()
})

test('retryCondition receives abort signal and stops retrying when cache entry is removed', async () => {
let capturedSignal: AbortSignal | undefined
let retryAttempts = 0

const baseBaseQuery = vi.fn<
Parameters<BaseQueryFn>,
ReturnType<BaseQueryFn>
>()

// Always return an error to trigger retries
baseBaseQuery.mockResolvedValue({ error: 'network error' })

let retryConditionCalled = false

const baseQuery = retry(baseBaseQuery, {
retryCondition: (error, args, { attempt, baseQueryApi }) => {
retryConditionCalled = true
retryAttempts = attempt
capturedSignal = baseQueryApi.signal

// Stop retrying if the signal is aborted
if (baseQueryApi.signal.aborted) {
return false
}

// Otherwise, retry up to 10 times
return attempt <= 10
},
backoff: async () => {
// Short backoff for faster test
await new Promise((resolve) => setTimeout(resolve, 10))
},
})

const api = createApi({
baseQuery,
endpoints: (build) => ({
getTest: build.query<string, number>({
query: (id) => ({ url: `test/${id}` }),
keepUnusedDataFor: 0.01, // Very short timeout (10ms)
}),
}),
})

const storeRef = setupApiStore(api, undefined, {
withoutTestLifecycles: true,
})

// Start the query
const queryPromise = storeRef.store.dispatch(
api.endpoints.getTest.initiate(1),
)

// Wait for the first retry to happen so we capture the signal
await loopTimers(2)

// Verify the retry condition was called and we have a signal
expect(retryConditionCalled).toBe(true)
expect(capturedSignal).toBeDefined()
expect(capturedSignal!.aborted).toBe(false)

// Unsubscribe to trigger cache removal
queryPromise.unsubscribe()

// Wait for the cache entry to be removed (keepUnusedDataFor: 0.01s = 10ms)
await vi.advanceTimersByTimeAsync(50)

// Allow some time for more retries to potentially happen
await loopTimers(3)

// The signal should now be aborted
expect(capturedSignal!.aborted).toBe(true)

// We should have stopped retrying early due to the abort signal
// If abort signal wasn't working, we'd see many more retry attempts
expect(retryAttempts).toBeLessThan(10)

// The base query should have been called at least once (initial attempt)
// but not the full 10+ times it would without abort signal
expect(baseBaseQuery).toHaveBeenCalled()
expect(baseBaseQuery.mock.calls.length).toBeLessThan(10)
})
})
Loading