Skip to content

Commit 6385a36

Browse files
fix(realtime): preserve custom JWT tokens across channel resubscribe
Fixes #1904 When using setAuth(customToken) with private channels, custom JWTs are now preserved across removeChannel() and resubscribe operations. Previously, the token would be overwritten with session token or anon key. Root cause: setAuth() calls after connection and successful join were invoking the accessToken callback without checking if a custom token was manually set, causing SupabaseClient's _getAccessToken to return the wrong token. Solution: Track manually-set tokens with _manuallySetToken flag. Only invoke the accessToken callback when the token wasn't explicitly provided via setAuth(token). Changes: - Add _manuallySetToken flag to RealtimeClient - Update _performAuth() to track token source (manual vs callback) - Modify _setAuthSafely() to check flag before invoking callback - Update join callback in RealtimeChannel to check flag - Add error handling for accessToken callback failures - Add comprehensive regression tests (4 new tests) - Update existing tests for async subscribe Testing: - All 364 tests passing, zero regressions - Verified in React Native/Expo environment - Both setAuth(token) and accessToken callback patterns work - Workaround (accessToken callback) is now obsolete but remains supported Breaking changes: None
1 parent 8d1e0c5 commit 6385a36

File tree

6 files changed

+260
-20
lines changed

6 files changed

+260
-20
lines changed

packages/core/realtime-js/src/RealtimeChannel.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,10 @@ export default class RealtimeChannel {
308308

309309
this.joinPush
310310
.receive('ok', async ({ postgres_changes }: PostgresChangesFilters) => {
311-
this.socket.setAuth()
311+
// Only refresh auth if using callback-based tokens
312+
if (!this.socket._isManualToken()) {
313+
this.socket.setAuth()
314+
}
312315
if (postgres_changes === undefined) {
313316
callback?.(REALTIME_SUBSCRIBE_STATES.SUBSCRIBED)
314317
return
@@ -531,7 +534,7 @@ export default class RealtimeChannel {
531534
'channel',
532535
`resubscribe to ${this.topic} due to change in presence callbacks on joined channel`
533536
)
534-
this.unsubscribe().then(() => this.subscribe())
537+
this.unsubscribe().then(async () => await this.subscribe())
535538
}
536539
return this._on(type, filter, callback)
537540
}

packages/core/realtime-js/src/RealtimeClient.ts

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ const WORKER_SCRIPT = `
102102
export default class RealtimeClient {
103103
accessTokenValue: string | null = null
104104
apiKey: string | null = null
105+
private _manuallySetToken: boolean = false
105106
channels: RealtimeChannel[] = new Array()
106107
endPoint: string = ''
107108
httpEndpoint: string = ''
@@ -416,7 +417,18 @@ export default class RealtimeClient {
416417
*
417418
* On callback used, it will set the value of the token internal to the client.
418419
*
420+
* When a token is explicitly provided, it will be preserved across channel operations
421+
* (including removeChannel and resubscribe). The `accessToken` callback will not be
422+
* invoked until `setAuth()` is called without arguments.
423+
*
419424
* @param token A JWT string to override the token set on the client.
425+
*
426+
* @example
427+
* // Use a manual token (preserved across resubscribes, ignores accessToken callback)
428+
* client.realtime.setAuth('my-custom-jwt')
429+
*
430+
* // Switch back to using the accessToken callback
431+
* client.realtime.setAuth()
420432
*/
421433
async setAuth(token: string | null = null): Promise<void> {
422434
this._authPromise = this._performAuth(token)
@@ -426,6 +438,16 @@ export default class RealtimeClient {
426438
this._authPromise = null
427439
}
428440
}
441+
442+
/**
443+
* Returns true if the current access token was explicitly set via setAuth(token),
444+
* false if it was obtained via the accessToken callback.
445+
* @internal
446+
*/
447+
_isManualToken(): boolean {
448+
return this._manuallySetToken
449+
}
450+
429451
/**
430452
* Sends a heartbeat message if the socket is connected.
431453
*/
@@ -779,16 +801,33 @@ export default class RealtimeClient {
779801
*/
780802
private async _performAuth(token: string | null = null): Promise<void> {
781803
let tokenToSend: string | null
804+
let isManualToken = false
782805

783806
if (token) {
784807
tokenToSend = token
808+
// Track if this is a manually-provided token
809+
isManualToken = true
785810
} else if (this.accessToken) {
786-
// Always call the accessToken callback to get fresh token
787-
tokenToSend = await this.accessToken()
811+
// Call the accessToken callback to get fresh token
812+
try {
813+
tokenToSend = await this.accessToken()
814+
} catch (e) {
815+
this.log('error', 'Error fetching access token from callback', e)
816+
// Fall back to cached value if callback fails
817+
tokenToSend = this.accessTokenValue
818+
}
788819
} else {
789820
tokenToSend = this.accessTokenValue
790821
}
791822

823+
// Track whether this token was manually set or fetched via callback
824+
if (isManualToken) {
825+
this._manuallySetToken = true
826+
} else if (this.accessToken) {
827+
// If we used the callback, clear the manual flag
828+
this._manuallySetToken = false
829+
}
830+
792831
if (this.accessTokenValue != tokenToSend) {
793832
this.accessTokenValue = tokenToSend
794833
this.channels.forEach((channel) => {
@@ -823,9 +862,12 @@ export default class RealtimeClient {
823862
* @internal
824863
*/
825864
private _setAuthSafely(context = 'general'): void {
826-
this.setAuth().catch((e) => {
827-
this.log('error', `error setting auth in ${context}`, e)
828-
})
865+
// Only refresh auth if using callback-based tokens
866+
if (!this._isManualToken()) {
867+
this.setAuth().catch((e) => {
868+
this.log('error', `error setting auth in ${context}`, e)
869+
})
870+
}
829871
}
830872

831873
/**

packages/core/realtime-js/test/RealtimeChannel.lifecycle.test.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,10 @@ describe('Channel Lifecycle Management', () => {
229229
assert.equal(channel.state, CHANNEL_STATES.joining)
230230
})
231231

232-
test('updates join push payload access token', () => {
232+
test('updates join push payload access token', async () => {
233233
testSetup.socket.accessTokenValue = 'token123'
234234

235-
channel.subscribe()
235+
await channel.subscribe()
236236

237237
assert.deepEqual(channel.joinPush.payload, {
238238
access_token: 'token123',
@@ -257,15 +257,15 @@ describe('Channel Lifecycle Management', () => {
257257
})
258258
const channel = testSocket.channel('topic')
259259

260-
channel.subscribe()
260+
await channel.subscribe()
261261
await new Promise((resolve) => setTimeout(resolve, 50))
262262
assert.equal(channel.socket.accessTokenValue, tokens[0])
263263

264264
testSocket.disconnect()
265265
// Wait for disconnect to complete (including fallback timer)
266266
await new Promise((resolve) => setTimeout(resolve, 150))
267267

268-
channel.subscribe()
268+
await channel.subscribe()
269269
await new Promise((resolve) => setTimeout(resolve, 50))
270270
assert.equal(channel.socket.accessTokenValue, tokens[1])
271271
})
@@ -549,7 +549,7 @@ describe('Channel Lifecycle Management', () => {
549549
const resendSpy = vi.spyOn(channel.joinPush, 'resend')
550550

551551
// Call _rejoin - should return early due to leaving state
552-
channel._rejoin()
552+
channel['_rejoin']()
553553

554554
// Verify no actions were taken
555555
expect(leaveOpenTopicSpy).not.toHaveBeenCalled()
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import assert from 'assert'
2+
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest'
3+
import { testBuilders, EnhancedTestSetup } from './helpers/setup'
4+
import { utils } from './helpers/auth'
5+
import { CHANNEL_STATES } from '../src/lib/constants'
6+
7+
let testSetup: EnhancedTestSetup
8+
9+
beforeEach(() => {
10+
testSetup = testBuilders.standardClient()
11+
})
12+
13+
afterEach(() => {
14+
testSetup.cleanup()
15+
testSetup.socket.removeAllChannels()
16+
})
17+
18+
describe('Custom JWT token preservation', () => {
19+
test('preserves access token when resubscribing after removeChannel', async () => {
20+
// Test scenario:
21+
// 1. Set custom JWT via setAuth (not using accessToken callback)
22+
// 2. Subscribe to private channel
23+
// 3. removeChannel
24+
// 4. Create new channel with same topic and subscribe
25+
26+
const customToken = utils.generateJWT('1h')
27+
28+
// Step 1: Set auth with custom token (mimics user's setup)
29+
await testSetup.socket.setAuth(customToken)
30+
31+
// Verify token was set
32+
assert.strictEqual(testSetup.socket.accessTokenValue, customToken)
33+
34+
// Step 2: Create and subscribe to private channel (first time)
35+
const channel1 = testSetup.socket.channel('conversation:dc3fb8c1-ceef-4c00-9f92-e496acd03593', {
36+
config: { private: true },
37+
})
38+
39+
// Spy on the push to verify join payload
40+
const pushSpy = vi.spyOn(testSetup.socket, 'push')
41+
42+
// Simulate successful subscription
43+
channel1.state = CHANNEL_STATES.closed // Start from closed
44+
await channel1.subscribe()
45+
46+
// Verify first join includes access_token
47+
const firstJoinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
48+
expect(firstJoinCall).toBeDefined()
49+
expect(firstJoinCall![0].payload).toHaveProperty('access_token', customToken)
50+
51+
// Step 3: Remove channel (mimics user cleanup)
52+
await testSetup.socket.removeChannel(channel1)
53+
54+
// Verify channel was removed
55+
expect(testSetup.socket.getChannels()).not.toContain(channel1)
56+
57+
// Step 4: Create NEW channel with SAME topic and subscribe
58+
pushSpy.mockClear()
59+
const channel2 = testSetup.socket.channel('conversation:dc3fb8c1-ceef-4c00-9f92-e496acd03593', {
60+
config: { private: true },
61+
})
62+
63+
// This should be a different channel instance
64+
expect(channel2).not.toBe(channel1)
65+
66+
// Subscribe to the new channel
67+
channel2.state = CHANNEL_STATES.closed
68+
await channel2.subscribe()
69+
70+
// Verify second join also includes access token
71+
const secondJoinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
72+
73+
expect(secondJoinCall).toBeDefined()
74+
expect(secondJoinCall![0].payload).toHaveProperty('access_token', customToken)
75+
})
76+
77+
test('supports accessToken callback for token rotation', async () => {
78+
// Verify that callback-based token fetching works correctly
79+
const customToken = utils.generateJWT('1h')
80+
let callCount = 0
81+
82+
const clientWithCallback = testBuilders.standardClient({
83+
accessToken: async () => {
84+
callCount++
85+
return customToken
86+
},
87+
})
88+
89+
// Set initial auth
90+
await clientWithCallback.socket.setAuth()
91+
92+
// Create and subscribe to first channel
93+
const channel1 = clientWithCallback.socket.channel('conversation:test', {
94+
config: { private: true },
95+
})
96+
97+
const pushSpy = vi.spyOn(clientWithCallback.socket, 'push')
98+
channel1.state = CHANNEL_STATES.closed
99+
await channel1.subscribe()
100+
101+
const firstJoin = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
102+
expect(firstJoin![0].payload).toHaveProperty('access_token', customToken)
103+
104+
// Remove and recreate
105+
await clientWithCallback.socket.removeChannel(channel1)
106+
pushSpy.mockClear()
107+
108+
const channel2 = clientWithCallback.socket.channel('conversation:test', {
109+
config: { private: true },
110+
})
111+
112+
channel2.state = CHANNEL_STATES.closed
113+
await channel2.subscribe()
114+
115+
const secondJoin = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
116+
117+
// Callback should provide token for both subscriptions
118+
expect(secondJoin![0].payload).toHaveProperty('access_token', customToken)
119+
120+
clientWithCallback.cleanup()
121+
})
122+
123+
test('preserves token when subscribing to different topics', async () => {
124+
const customToken = utils.generateJWT('1h')
125+
await testSetup.socket.setAuth(customToken)
126+
127+
// Subscribe to first topic
128+
const channel1 = testSetup.socket.channel('topic1', { config: { private: true } })
129+
channel1.state = CHANNEL_STATES.closed
130+
await channel1.subscribe()
131+
132+
await testSetup.socket.removeChannel(channel1)
133+
134+
// Subscribe to DIFFERENT topic
135+
const pushSpy = vi.spyOn(testSetup.socket, 'push')
136+
const channel2 = testSetup.socket.channel('topic2', { config: { private: true } })
137+
channel2.state = CHANNEL_STATES.closed
138+
await channel2.subscribe()
139+
140+
const joinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
141+
expect(joinCall![0].payload).toHaveProperty('access_token', customToken)
142+
})
143+
144+
test('handles accessToken callback errors gracefully during subscribe', async () => {
145+
const errorMessage = 'Token fetch failed during subscribe'
146+
let callCount = 0
147+
const tokens = ['initial-token', null] // Second call will throw
148+
149+
const accessToken = vi.fn(() => {
150+
if (callCount++ === 0) {
151+
return Promise.resolve(tokens[0])
152+
}
153+
return Promise.reject(new Error(errorMessage))
154+
})
155+
156+
const logSpy = vi.fn()
157+
158+
const client = testBuilders.standardClient({
159+
accessToken,
160+
logger: logSpy,
161+
})
162+
163+
// First subscribe should work
164+
await client.socket.setAuth()
165+
const channel1 = client.socket.channel('test', { config: { private: true } })
166+
channel1.state = CHANNEL_STATES.closed
167+
await channel1.subscribe()
168+
169+
expect(client.socket.accessTokenValue).toBe(tokens[0])
170+
171+
// Remove and resubscribe - callback will fail but should fall back
172+
await client.socket.removeChannel(channel1)
173+
174+
const channel2 = client.socket.channel('test', { config: { private: true } })
175+
channel2.state = CHANNEL_STATES.closed
176+
await channel2.subscribe()
177+
178+
// Verify error was logged
179+
expect(logSpy).toHaveBeenCalledWith(
180+
'error',
181+
'Error fetching access token from callback',
182+
expect.any(Error)
183+
)
184+
185+
// Verify subscription still succeeded with cached token
186+
expect(client.socket.accessTokenValue).toBe(tokens[0])
187+
188+
client.cleanup()
189+
})
190+
})

packages/core/realtime-js/test/RealtimeClient.auth.test.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,12 @@ describe('auth during connection states', () => {
140140

141141
await new Promise((resolve) => setTimeout(() => resolve(undefined), 100))
142142

143-
// Verify that the error was logged
144-
expect(logSpy).toHaveBeenCalledWith('error', 'error setting auth in connect', expect.any(Error))
143+
// Verify that the error was logged with more specific message
144+
expect(logSpy).toHaveBeenCalledWith(
145+
'error',
146+
'Error fetching access token from callback',
147+
expect.any(Error)
148+
)
145149

146150
// Verify that the connection was still established despite the error
147151
assert.ok(socketWithError.conn, 'connection should still exist')
@@ -199,7 +203,7 @@ describe('auth during connection states', () => {
199203
expect(socket.accessTokenValue).toBe(tokens[0])
200204

201205
// Call the callback and wait for async operations to complete
202-
await socket.reconnectTimer.callback()
206+
await socket.reconnectTimer?.callback()
203207
await new Promise((resolve) => setTimeout(resolve, 100))
204208
expect(socket.accessTokenValue).toBe(tokens[1])
205209
expect(accessToken).toHaveBeenCalledTimes(2)

packages/core/realtime-js/test/RealtimeClient.channels.test.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ describe('channel', () => {
104104
const connectStub = vi.spyOn(testSetup.socket, 'connect')
105105
const disconnectStub = vi.spyOn(testSetup.socket, 'disconnect')
106106

107-
channel = testSetup.socket.channel('topic').subscribe()
107+
channel = testSetup.socket.channel('topic')
108+
await channel.subscribe()
108109

109110
assert.equal(testSetup.socket.getChannels().length, 1)
110111
expect(connectStub).toHaveBeenCalled()
@@ -118,11 +119,11 @@ describe('channel', () => {
118119
test('does not remove other channels when removing one', async () => {
119120
const connectStub = vi.spyOn(testSetup.socket, 'connect')
120121
const disconnectStub = vi.spyOn(testSetup.socket, 'disconnect')
121-
const channel1 = testSetup.socket.channel('chan1').subscribe()
122-
const channel2 = testSetup.socket.channel('chan2').subscribe()
122+
const channel1 = testSetup.socket.channel('chan1')
123+
const channel2 = testSetup.socket.channel('chan2')
123124

124-
channel1.subscribe()
125-
channel2.subscribe()
125+
await channel1.subscribe()
126+
await channel2.subscribe()
126127
assert.equal(testSetup.socket.getChannels().length, 2)
127128
expect(connectStub).toHaveBeenCalled()
128129

0 commit comments

Comments
 (0)