@@ -6,7 +6,7 @@ const cluster = require('cluster');
66const redis = require ( './broker/redis' ) ;
77const Client = require ( './client' ) ;
88const { checkToken, readJson, setCorsHeaders } = require ( "./utils/guard" ) ;
9- const { getCPULoadAVG, getLocalIp, getVersionNum, ensureDirectory } = require ( "./utils/tool" ) ;
9+ const { getCPULoadAVG, getLocalIp, getVersionNum, ensureDirectory, getDomainFromOrigin } = require ( "./utils/tool" ) ;
1010// const cpuOverload = new (require('./utils/cpuOverload'))(10, 80, 0.8);
1111const decoder = new StringDecoder ( ) ;
1212const inspector = require ( "node:inspector" ) ;
@@ -16,7 +16,7 @@ const versionNumber = getVersionNum(version);
1616const COMPACT_VERSION = "1" ;
1717
1818class Server {
19- constructor ( hub , logger , endPoint , stats , ratelimit , security , compression , compact , extra = { } ) {
19+ constructor ( hub , logger , endPoint , stats , ratelimit , security , compression , compact , access , extra = { } ) {
2020 this . logger = logger ;
2121 this . extra = extra ;
2222 this . hub = hub ;
@@ -33,19 +33,64 @@ class Server {
3333 this . limiter = new RateLimiter ( { tokensPerInterval : ratelimit . max_rate , interval : "second" } ) ;
3434 }
3535 this . security = security || { } ;
36+ this . access = access || { } ;
37+ this . validateOrigin = false ;
3638 this . stats = stats || { } ;
3739 this . compression = compression || { } ;
3840 this . compact = compact || { } ;
3941 this . internalIp = getLocalIp ( ) ;
4042 // cpuOverload.check().then().catch(err => {
4143 // logger.error(err)
4244 // });
45+ this . validateAccess ( ) ;
4346 this . app = this . isSSL ? SSLApp ( {
4447 key_file_name : endPoint . key ,
4548 cert_file_name : endPoint . cert ,
4649 } ) : App ( ) ;
4750 }
4851
52+ validateAccess ( ) {
53+ if ( this . access . allowDomains !== undefined ) {
54+ if ( this . access . denyDomains !== undefined ) {
55+ throw new Error (
56+ "allowDomains and denyDomains can't be set simultaneously" ,
57+ ) ;
58+ } else if ( ! ( this . access . allowDomains instanceof Array ) ) {
59+ throw new Error (
60+ "allowDomains configuration parameters should be an array of strings" ,
61+ ) ;
62+ }
63+ } else if (
64+ this . access . denyDomains !== undefined &&
65+ ! ( this . access . denyDomains instanceof Array )
66+ ) {
67+ throw new Error (
68+ "denyDomains configuration parameters should be an array of strings" ,
69+ ) ;
70+ }
71+
72+ const origins = this . access . allowDomains ?? this . access . denyDomains ;
73+
74+ if ( origins !== undefined ) {
75+ for ( const origin of origins ) {
76+ if ( typeof origin !== "string" ) {
77+ throw new Error (
78+ "allowDomains and denyDomains configuration parameters should be arrays of strings" ,
79+ ) ;
80+ }
81+ }
82+ }
83+
84+ if ( this . access . allowDomains ) {
85+ this . access . allowDomains = new Set ( this . access . allowDomains ) ;
86+ this . validateOrigin = true
87+ }
88+ if ( this . access . denyDomains ) {
89+ this . access . denyDomains = new Set ( this . access . denyDomains ) ;
90+ this . validateOrigin = true
91+ }
92+ }
93+
4994 buildServer ( ) {
5095 this . app
5196 . get ( '/health' , ( response ) => {
@@ -263,6 +308,7 @@ class Server {
263308 const secWebSocketKey = req . getHeader ( "sec-websocket-key" ) ;
264309 const secWebSocketProtocol = req . getHeader ( "sec-websocket-protocol" ) ;
265310 const secWebSocketExtensions = req . getHeader ( "sec-websocket-extensions" ) ;
311+ const origin = req . getHeader ( "origin" ) ;
266312 const id = req . getQuery ( "id" ) ;
267313 const token = req . getQuery ( "token" ) ;
268314 const compactVersion = req . getQuery ( "c" ) ;
@@ -279,7 +325,7 @@ class Server {
279325 return
280326 }
281327 res . upgrade (
282- { id, token, compactVersion, device, batchable } ,
328+ { id, token, compactVersion, device, batchable, origin } ,
283329 secWebSocketKey ,
284330 secWebSocketProtocol ,
285331 secWebSocketExtensions ,
@@ -288,7 +334,7 @@ class Server {
288334 }
289335
290336 _onOpen ( ws ) {
291- const { id, token, compactVersion, device, batchable } = ws . getUserData ( ) ;
337+ const { id, token, compactVersion, device, batchable, origin } = ws . getUserData ( ) ;
292338 if ( ! id || id . length < 6 ) {
293339 ws . end ( 4000 , 'id is not valid' ) ;
294340 return
@@ -297,6 +343,17 @@ class Server {
297343 ws . end ( 4000 , 'token is not valid' ) ;
298344 return
299345 }
346+ if ( this . validateOrigin ) {
347+ const domain = getDomainFromOrigin ( origin ) ;
348+ const shouldDeny = ! ! domain &&
349+ ( this . access . denyDomains ?. has ( domain ) ||
350+ this . access . allowDomains ?. has ( domain ) === false ) ;
351+ if ( shouldDeny ) {
352+ this . logger . warn ( `denied domain ${ domain } ` ) ;
353+ ws . end ( 4000 , 'domain is denied' ) ;
354+ return
355+ }
356+ }
300357 let client = this . hub . getClient ( id ) ;
301358 if ( client ) {
302359 // this.logger.info(`${id} is already exist`);
0 commit comments