Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions lib/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uid2 = require("uid2");
import msgpack = require("notepack.io");
import { Adapter, BroadcastOptions, Room } from "socket.io-adapter";
import { PUBSUB } from "./util";
import { parseTimeout, PUBSUB } from "./util";

const debug = require("debug")("socket.io-redis");

Expand Down Expand Up @@ -36,6 +36,7 @@ interface Request {
interface AckRequest {
clientCountCallback: (clientCount: number) => void;
ack: (...args: any[]) => void;
timeout: NodeJS.Timeout;
}

interface Parser {
Expand Down Expand Up @@ -128,7 +129,7 @@ export class RedisAdapter extends Adapter {
super(nsp);

this.uid = uid2(6);
this.requestsTimeout = opts.requestsTimeout || 5000;
this.requestsTimeout = parseTimeout(opts.requestsTimeout, 5000);
this.publishOnSpecificResponseChannel =
!!opts.publishOnSpecificResponseChannel;
this.parser = opts.parser || msgpack;
Expand Down Expand Up @@ -185,6 +186,7 @@ export class RedisAdapter extends Adapter {
);
}

// Use function() instead of arrow function so 'this' refers to the Redis client (event emitter)
this.friendlyErrorHandler = function () {
if (this.listenerCount("error") === 1) {
console.warn("missing 'error' handler on this Redis client");
Expand Down Expand Up @@ -542,9 +544,9 @@ export class RedisAdapter extends Adapter {
request.msgCount++;

// ignore if response does not contain 'sockets' key
if (!response.sockets || !Array.isArray(response.sockets)) return;

if (request.type === RequestType.SOCKETS) {
if (!response.sockets || !Array.isArray(response.sockets)) {
debug("ignoring malformed response (missing sockets array)");
} else if (request.type === RequestType.SOCKETS) {
response.sockets.forEach((s) => request.sockets.add(s));
} else {
response.sockets.forEach((s) => request.sockets.push(s));
Expand All @@ -563,9 +565,11 @@ export class RedisAdapter extends Adapter {
request.msgCount++;

// ignore if response does not contain 'rooms' key
if (!response.rooms || !Array.isArray(response.rooms)) return;

response.rooms.forEach((s) => request.rooms.add(s));
if (!response.rooms || !Array.isArray(response.rooms)) {
debug("ignoring malformed response (missing rooms array)");
} else {
response.rooms.forEach((s) => request.rooms.add(s));
}

if (request.msgCount === request.numSub) {
clearTimeout(request.timeout);
Expand Down Expand Up @@ -667,16 +671,22 @@ export class RedisAdapter extends Adapter {

this.pubClient.publish(this.requestChannel, request);

// we have no way to know at this level whether the server has received an acknowledgement from each client, so we
// will simply clean up the ackRequests map after the given delay
const ackTimeout = parseTimeout(
opts.flags?.timeout,
this.requestsTimeout
);

const timeout = setTimeout(() => {
this.ackRequests.delete(requestId);
}, ackTimeout);

this.ackRequests.set(requestId, {
clientCountCallback,
ack,
timeout,
});

// we have no way to know at this level whether the server has received an acknowledgement from each client, so we
// will simply clean up the ackRequests map after the given delay
setTimeout(() => {
this.ackRequests.delete(requestId);
}, opts.flags!.timeout);
}

super.broadcastWithAck(packet, opts, clientCountCallback, ack);
Expand Down Expand Up @@ -895,6 +905,18 @@ export class RedisAdapter extends Adapter {
}

close(): Promise<void> | void {
// Cancel all pending request timeouts and clear the map
this.requests.forEach((request) => {
clearTimeout(request.timeout);
});
this.requests.clear();

// Cancel all pending ack request timeouts and clear the map
this.ackRequests.forEach((ackRequest) => {
clearTimeout(ackRequest.timeout);
});
this.ackRequests.clear();

const isRedisV4 = typeof this.pubClient.pSubscribe === "function";
if (isRedisV4) {
this.subClient.pUnsubscribe(
Expand Down Expand Up @@ -940,6 +962,9 @@ export class RedisAdapter extends Adapter {

this.pubClient.off("error", this.friendlyErrorHandler);
this.subClient.off("error", this.friendlyErrorHandler);

// Clear listener references
this.redisListeners.clear();
}
}

Expand Down
21 changes: 17 additions & 4 deletions lib/sharded-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class ShardedRedisAdapter extends ClusterAdapter {
private readonly opts: Required<ShardedRedisAdapterOptions>;
private readonly channel: string;
private readonly responseChannel: string;
private readonly onCreateRoom: (room: string) => void;
private readonly onDeleteRoom: (room: string) => void;

constructor(nsp, pubClient, subClient, opts: ShardedRedisAdapterOptions) {
super(nsp);
Expand All @@ -97,17 +99,20 @@ class ShardedRedisAdapter extends ClusterAdapter {
this.opts.subscriptionMode === "dynamic" ||
this.opts.subscriptionMode === "dynamic-private"
) {
this.on("create-room", (room) => {
this.onCreateRoom = (room) => {
if (this.shouldUseASeparateNamespace(room)) {
SSUBSCRIBE(this.subClient, this.dynamicChannel(room), handler);
}
});
};

this.on("delete-room", (room) => {
this.onDeleteRoom = (room) => {
if (this.shouldUseASeparateNamespace(room)) {
SUNSUBSCRIBE(this.subClient, this.dynamicChannel(room));
}
});
};

this.on("create-room", this.onCreateRoom);
this.on("delete-room", this.onDeleteRoom);
}
}

Expand All @@ -118,6 +123,14 @@ class ShardedRedisAdapter extends ClusterAdapter {
this.opts.subscriptionMode === "dynamic" ||
this.opts.subscriptionMode === "dynamic-private"
) {
// Remove event listeners to prevent memory leaks
if (this.onCreateRoom) {
this.off("create-room", this.onCreateRoom);
}
if (this.onDeleteRoom) {
this.off("delete-room", this.onDeleteRoom);
}

this.rooms.forEach((_sids, room) => {
if (this.shouldUseASeparateNamespace(room)) {
channels.push(this.dynamicChannel(room));
Expand Down
81 changes: 61 additions & 20 deletions lib/util.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
export function parseTimeout(value: unknown, defaultValue: number): number {
return typeof value === "number" && value > 0 && value !== Infinity
? value
: defaultValue;
}

export function hasBinary(obj: any, toJSON?: boolean): boolean {
if (!obj || typeof obj !== "object") {
return false;
Expand Down Expand Up @@ -53,39 +59,74 @@ function isRedisV4Client(redisClient: any) {
}

const kHandlers = Symbol("handlers");
const kListener = Symbol("listener");
const kPendingUnsubscribes = Symbol("pendingUnsubscribes");

export function SSUBSCRIBE(
redisClient: any,
channel: string,
handler: (rawMessage: Buffer, channel: Buffer) => void
) {
): Promise<void> {
if (isRedisV4Client(redisClient)) {
redisClient.sSubscribe(channel, handler, RETURN_BUFFERS);
return redisClient.sSubscribe(channel, handler, RETURN_BUFFERS);
} else {
if (!redisClient[kHandlers]) {
redisClient[kHandlers] = new Map();
redisClient.on("smessageBuffer", (rawChannel, message) => {
redisClient[kHandlers].get(rawChannel.toString())?.(
message,
rawChannel
);
});
const doSubscribe = (): Promise<void> => {
if (!redisClient[kHandlers]) {
redisClient[kHandlers] = new Map<string, typeof handler>();
redisClient[kPendingUnsubscribes] = new Map<string, Promise<void>>();
redisClient[kListener] = (rawChannel: Buffer, message: Buffer) => {
redisClient[kHandlers].get(rawChannel.toString())?.(
message,
rawChannel
);
};
redisClient.on("smessageBuffer", redisClient[kListener]);
}
redisClient[kHandlers].set(channel, handler);
return redisClient.ssubscribe(channel);
};

// Wait for any pending unsubscribe on this channel to complete first
const pendingUnsubscribe = redisClient[kPendingUnsubscribes]?.get(channel);
if (pendingUnsubscribe) {
return pendingUnsubscribe.then(doSubscribe);
}
redisClient[kHandlers].set(channel, handler);
redisClient.ssubscribe(channel);
return doSubscribe();
}
}

export function SUNSUBSCRIBE(redisClient: any, channel: string | string[]) {
export function SUNSUBSCRIBE(
redisClient: any,
channel: string | string[]
): Promise<void> {
if (isRedisV4Client(redisClient)) {
redisClient.sUnsubscribe(channel);
return redisClient.sUnsubscribe(channel);
} else {
redisClient.sunsubscribe(channel);
if (Array.isArray(channel)) {
channel.forEach((c) => redisClient[kHandlers].delete(c));
} else {
redisClient[kHandlers].delete(channel);
}
const channels = Array.isArray(channel) ? channel : [channel];

// Remove handlers immediately to stop processing messages
channels.forEach((c) => redisClient[kHandlers]?.delete(c));

// Perform the unsubscribe and track as pending
const unsubscribePromise = redisClient.sunsubscribe(channel).then(() => {
// Remove from pending tracking
channels.forEach((c) => redisClient[kPendingUnsubscribes]?.delete(c));

// Clean up the global listener when no more handlers exist
if (redisClient[kHandlers]?.size === 0 && redisClient[kListener]) {
redisClient.off("smessageBuffer", redisClient[kListener]);
delete redisClient[kHandlers];
delete redisClient[kListener];
delete redisClient[kPendingUnsubscribes];
}
});

// Track pending unsubscribe for each channel
channels.forEach((c) =>
redisClient[kPendingUnsubscribes]?.set(c, unsubscribePromise)
);

return unsubscribePromise;
}
}

Expand Down
14 changes: 11 additions & 3 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@socket.io/redis-adapter",
"version": "8.3.0",
"version": "8.3.1",
"description": "The Socket.IO Redis adapter, allowing to broadcast events between several Socket.IO servers",
"license": "MIT",
"repository": {
Expand Down