Skip to content

Commit 139a745

Browse files
fix(realtime): preserve custom JWT tokens across channel resubscribe (#1908)
1 parent 66351aa commit 139a745

File tree

6 files changed

+259
-19
lines changed

6 files changed

+259
-19
lines changed

packages/core/realtime-js/src/RealtimeChannel.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,10 @@ export default class RealtimeChannel {
308308

309309
this.joinPush
310310
.receive('ok', async ({ postgres_changes }: PostgresChangesFilters) => {
311-
this.socket.setAuth()
311+
// Only refresh auth if using callback-based tokens
312+
if (!this.socket._isManualToken()) {
313+
this.socket.setAuth()
314+
}
312315
if (postgres_changes === undefined) {
313316
callback?.(REALTIME_SUBSCRIBE_STATES.SUBSCRIBED)
314317
return
@@ -531,7 +534,7 @@ export default class RealtimeChannel {
531534
'channel',
532535
`resubscribe to ${this.topic} due to change in presence callbacks on joined channel`
533536
)
534-
this.unsubscribe().then(() => this.subscribe())
537+
this.unsubscribe().then(async () => await this.subscribe())
535538
}
536539
return this._on(type, filter, callback)
537540
}

packages/core/realtime-js/src/RealtimeClient.ts

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ const WORKER_SCRIPT = `
102102
export default class RealtimeClient {
103103
accessTokenValue: string | null = null
104104
apiKey: string | null = null
105+
private _manuallySetToken: boolean = false
105106
channels: RealtimeChannel[] = new Array()
106107
endPoint: string = ''
107108
httpEndpoint: string = ''
@@ -416,7 +417,18 @@ export default class RealtimeClient {
416417
*
417418
* On callback used, it will set the value of the token internal to the client.
418419
*
420+
* When a token is explicitly provided, it will be preserved across channel operations
421+
* (including removeChannel and resubscribe). The `accessToken` callback will not be
422+
* invoked until `setAuth()` is called without arguments.
423+
*
419424
* @param token A JWT string to override the token set on the client.
425+
*
426+
* @example
427+
* // Use a manual token (preserved across resubscribes, ignores accessToken callback)
428+
* client.realtime.setAuth('my-custom-jwt')
429+
*
430+
* // Switch back to using the accessToken callback
431+
* client.realtime.setAuth()
420432
*/
421433
async setAuth(token: string | null = null): Promise<void> {
422434
this._authPromise = this._performAuth(token)
@@ -426,6 +438,16 @@ export default class RealtimeClient {
426438
this._authPromise = null
427439
}
428440
}
441+
442+
/**
443+
* Returns true if the current access token was explicitly set via setAuth(token),
444+
* false if it was obtained via the accessToken callback.
445+
* @internal
446+
*/
447+
_isManualToken(): boolean {
448+
return this._manuallySetToken
449+
}
450+
429451
/**
430452
* Sends a heartbeat message if the socket is connected.
431453
*/
@@ -779,16 +801,33 @@ export default class RealtimeClient {
779801
*/
780802
private async _performAuth(token: string | null = null): Promise<void> {
781803
let tokenToSend: string | null
804+
let isManualToken = false
782805

783806
if (token) {
784807
tokenToSend = token
808+
// Track if this is a manually-provided token
809+
isManualToken = true
785810
} else if (this.accessToken) {
786-
// Always call the accessToken callback to get fresh token
787-
tokenToSend = await this.accessToken()
811+
// Call the accessToken callback to get fresh token
812+
try {
813+
tokenToSend = await this.accessToken()
814+
} catch (e) {
815+
this.log('error', 'Error fetching access token from callback', e)
816+
// Fall back to cached value if callback fails
817+
tokenToSend = this.accessTokenValue
818+
}
788819
} else {
789820
tokenToSend = this.accessTokenValue
790821
}
791822

823+
// Track whether this token was manually set or fetched via callback
824+
if (isManualToken) {
825+
this._manuallySetToken = true
826+
} else if (this.accessToken) {
827+
// If we used the callback, clear the manual flag
828+
this._manuallySetToken = false
829+
}
830+
792831
if (this.accessTokenValue != tokenToSend) {
793832
this.accessTokenValue = tokenToSend
794833
this.channels.forEach((channel) => {
@@ -823,9 +862,12 @@ export default class RealtimeClient {
823862
* @internal
824863
*/
825864
private _setAuthSafely(context = 'general'): void {
826-
this.setAuth().catch((e) => {
827-
this.log('error', `error setting auth in ${context}`, e)
828-
})
865+
// Only refresh auth if using callback-based tokens
866+
if (!this._isManualToken()) {
867+
this.setAuth().catch((e) => {
868+
this.log('error', `Error setting auth in ${context}`, e)
869+
})
870+
}
829871
}
830872

831873
/**

packages/core/realtime-js/test/RealtimeChannel.lifecycle.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,10 @@ describe('Channel Lifecycle Management', () => {
229229
assert.equal(channel.state, CHANNEL_STATES.joining)
230230
})
231231

232-
test('updates join push payload access token', () => {
232+
test('updates join push payload access token', async () => {
233233
testSetup.socket.accessTokenValue = 'token123'
234234

235-
channel.subscribe()
235+
await channel.subscribe()
236236

237237
assert.deepEqual(channel.joinPush.payload, {
238238
access_token: 'token123',
@@ -257,15 +257,15 @@ describe('Channel Lifecycle Management', () => {
257257
})
258258
const channel = testSocket.channel('topic')
259259

260-
channel.subscribe()
260+
await channel.subscribe()
261261
await new Promise((resolve) => setTimeout(resolve, 50))
262262
assert.equal(channel.socket.accessTokenValue, tokens[0])
263263

264264
testSocket.disconnect()
265265
// Wait for disconnect to complete (including fallback timer)
266266
await new Promise((resolve) => setTimeout(resolve, 150))
267267

268-
channel.subscribe()
268+
await channel.subscribe()
269269
await new Promise((resolve) => setTimeout(resolve, 50))
270270
assert.equal(channel.socket.accessTokenValue, tokens[1])
271271
})
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import assert from 'assert'
2+
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest'
3+
import { testBuilders, EnhancedTestSetup } from './helpers/setup'
4+
import { utils } from './helpers/auth'
5+
import { CHANNEL_STATES } from '../src/lib/constants'
6+
7+
let testSetup: EnhancedTestSetup
8+
9+
beforeEach(() => {
10+
testSetup = testBuilders.standardClient()
11+
})
12+
13+
afterEach(() => {
14+
testSetup.cleanup()
15+
testSetup.socket.removeAllChannels()
16+
})
17+
18+
describe('Custom JWT token preservation', () => {
19+
test('preserves access token when resubscribing after removeChannel', async () => {
20+
// Test scenario:
21+
// 1. Set custom JWT via setAuth (not using accessToken callback)
22+
// 2. Subscribe to private channel
23+
// 3. removeChannel
24+
// 4. Create new channel with same topic and subscribe
25+
26+
const customToken = utils.generateJWT('1h')
27+
28+
// Step 1: Set auth with custom token (mimics user's setup)
29+
await testSetup.socket.setAuth(customToken)
30+
31+
// Verify token was set
32+
assert.strictEqual(testSetup.socket.accessTokenValue, customToken)
33+
34+
// Step 2: Create and subscribe to private channel (first time)
35+
const channel1 = testSetup.socket.channel('conversation:dc3fb8c1-ceef-4c00-9f92-e496acd03593', {
36+
config: { private: true },
37+
})
38+
39+
// Spy on the push to verify join payload
40+
const pushSpy = vi.spyOn(testSetup.socket, 'push')
41+
42+
// Simulate successful subscription
43+
channel1.state = CHANNEL_STATES.closed // Start from closed
44+
await channel1.subscribe()
45+
46+
// Verify first join includes access_token
47+
const firstJoinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
48+
expect(firstJoinCall).toBeDefined()
49+
expect(firstJoinCall![0].payload).toHaveProperty('access_token', customToken)
50+
51+
// Step 3: Remove channel (mimics user cleanup)
52+
await testSetup.socket.removeChannel(channel1)
53+
54+
// Verify channel was removed
55+
expect(testSetup.socket.getChannels()).not.toContain(channel1)
56+
57+
// Step 4: Create NEW channel with SAME topic and subscribe
58+
pushSpy.mockClear()
59+
const channel2 = testSetup.socket.channel('conversation:dc3fb8c1-ceef-4c00-9f92-e496acd03593', {
60+
config: { private: true },
61+
})
62+
63+
// This should be a different channel instance
64+
expect(channel2).not.toBe(channel1)
65+
66+
// Subscribe to the new channel
67+
channel2.state = CHANNEL_STATES.closed
68+
await channel2.subscribe()
69+
70+
// Verify second join also includes access token
71+
const secondJoinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
72+
73+
expect(secondJoinCall).toBeDefined()
74+
expect(secondJoinCall![0].payload).toHaveProperty('access_token', customToken)
75+
})
76+
77+
test('supports accessToken callback for token rotation', async () => {
78+
// Verify that callback-based token fetching works correctly
79+
const customToken = utils.generateJWT('1h')
80+
let callCount = 0
81+
82+
const clientWithCallback = testBuilders.standardClient({
83+
accessToken: async () => {
84+
callCount++
85+
return customToken
86+
},
87+
})
88+
89+
// Set initial auth
90+
await clientWithCallback.socket.setAuth()
91+
92+
// Create and subscribe to first channel
93+
const channel1 = clientWithCallback.socket.channel('conversation:test', {
94+
config: { private: true },
95+
})
96+
97+
const pushSpy = vi.spyOn(clientWithCallback.socket, 'push')
98+
channel1.state = CHANNEL_STATES.closed
99+
await channel1.subscribe()
100+
101+
const firstJoin = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
102+
expect(firstJoin![0].payload).toHaveProperty('access_token', customToken)
103+
104+
// Remove and recreate
105+
await clientWithCallback.socket.removeChannel(channel1)
106+
pushSpy.mockClear()
107+
108+
const channel2 = clientWithCallback.socket.channel('conversation:test', {
109+
config: { private: true },
110+
})
111+
112+
channel2.state = CHANNEL_STATES.closed
113+
await channel2.subscribe()
114+
115+
const secondJoin = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
116+
117+
// Callback should provide token for both subscriptions
118+
expect(secondJoin![0].payload).toHaveProperty('access_token', customToken)
119+
120+
clientWithCallback.cleanup()
121+
})
122+
123+
test('preserves token when subscribing to different topics', async () => {
124+
const customToken = utils.generateJWT('1h')
125+
await testSetup.socket.setAuth(customToken)
126+
127+
// Subscribe to first topic
128+
const channel1 = testSetup.socket.channel('topic1', { config: { private: true } })
129+
channel1.state = CHANNEL_STATES.closed
130+
await channel1.subscribe()
131+
132+
await testSetup.socket.removeChannel(channel1)
133+
134+
// Subscribe to DIFFERENT topic
135+
const pushSpy = vi.spyOn(testSetup.socket, 'push')
136+
const channel2 = testSetup.socket.channel('topic2', { config: { private: true } })
137+
channel2.state = CHANNEL_STATES.closed
138+
await channel2.subscribe()
139+
140+
const joinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
141+
expect(joinCall![0].payload).toHaveProperty('access_token', customToken)
142+
})
143+
144+
test('handles accessToken callback errors gracefully during subscribe', async () => {
145+
const errorMessage = 'Token fetch failed during subscribe'
146+
let callCount = 0
147+
const tokens = ['initial-token', null] // Second call will throw
148+
149+
const accessToken = vi.fn(() => {
150+
if (callCount++ === 0) {
151+
return Promise.resolve(tokens[0])
152+
}
153+
return Promise.reject(new Error(errorMessage))
154+
})
155+
156+
const logSpy = vi.fn()
157+
158+
const client = testBuilders.standardClient({
159+
accessToken,
160+
logger: logSpy,
161+
})
162+
163+
// First subscribe should work
164+
await client.socket.setAuth()
165+
const channel1 = client.socket.channel('test', { config: { private: true } })
166+
channel1.state = CHANNEL_STATES.closed
167+
await channel1.subscribe()
168+
169+
expect(client.socket.accessTokenValue).toBe(tokens[0])
170+
171+
// Remove and resubscribe - callback will fail but should fall back
172+
await client.socket.removeChannel(channel1)
173+
174+
const channel2 = client.socket.channel('test', { config: { private: true } })
175+
channel2.state = CHANNEL_STATES.closed
176+
await channel2.subscribe()
177+
178+
// Verify error was logged
179+
expect(logSpy).toHaveBeenCalledWith(
180+
'error',
181+
'Error fetching access token from callback',
182+
expect.any(Error)
183+
)
184+
185+
// Verify subscription still succeeded with cached token
186+
expect(client.socket.accessTokenValue).toBe(tokens[0])
187+
188+
client.cleanup()
189+
})
190+
})

packages/core/realtime-js/test/RealtimeClient.auth.test.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,12 @@ describe('auth during connection states', () => {
140140

141141
await new Promise((resolve) => setTimeout(() => resolve(undefined), 100))
142142

143-
// Verify that the error was logged
144-
expect(logSpy).toHaveBeenCalledWith('error', 'error setting auth in connect', expect.any(Error))
143+
// Verify that the error was logged with more specific message
144+
expect(logSpy).toHaveBeenCalledWith(
145+
'error',
146+
'Error fetching access token from callback',
147+
expect.any(Error)
148+
)
145149

146150
// Verify that the connection was still established despite the error
147151
assert.ok(socketWithError.conn, 'connection should still exist')
@@ -199,7 +203,7 @@ describe('auth during connection states', () => {
199203
expect(socket.accessTokenValue).toBe(tokens[0])
200204

201205
// Call the callback and wait for async operations to complete
202-
await socket.reconnectTimer.callback()
206+
await socket.reconnectTimer?.callback()
203207
await new Promise((resolve) => setTimeout(resolve, 100))
204208
expect(socket.accessTokenValue).toBe(tokens[1])
205209
expect(accessToken).toHaveBeenCalledTimes(2)

packages/core/realtime-js/test/RealtimeClient.channels.test.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ describe('channel', () => {
104104
const connectStub = vi.spyOn(testSetup.socket, 'connect')
105105
const disconnectStub = vi.spyOn(testSetup.socket, 'disconnect')
106106

107-
channel = testSetup.socket.channel('topic').subscribe()
107+
channel = testSetup.socket.channel('topic')
108+
await channel.subscribe()
108109

109110
assert.equal(testSetup.socket.getChannels().length, 1)
110111
expect(connectStub).toHaveBeenCalled()
@@ -118,11 +119,11 @@ describe('channel', () => {
118119
test('does not remove other channels when removing one', async () => {
119120
const connectStub = vi.spyOn(testSetup.socket, 'connect')
120121
const disconnectStub = vi.spyOn(testSetup.socket, 'disconnect')
121-
const channel1 = testSetup.socket.channel('chan1').subscribe()
122-
const channel2 = testSetup.socket.channel('chan2').subscribe()
122+
const channel1 = testSetup.socket.channel('chan1')
123+
const channel2 = testSetup.socket.channel('chan2')
123124

124-
channel1.subscribe()
125-
channel2.subscribe()
125+
await channel1.subscribe()
126+
await channel2.subscribe()
126127
assert.equal(testSetup.socket.getChannels().length, 2)
127128
expect(connectStub).toHaveBeenCalled()
128129

0 commit comments

Comments
 (0)