Skip to content

Commit fd6e111

Browse files
authored
add polling transport option (#5955)
* add polling transport option * fix CORS * set number of workers to 1 * add enterprise guard * add once why not
1 parent 608f71e commit fd6e111

File tree

6 files changed

+34
-7
lines changed

6 files changed

+34
-7
lines changed

reflex/.templates/web/utils/state.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ export const useEventLoop = (
854854
await connect(
855855
socket,
856856
dispatch,
857-
["websocket"],
857+
[env.TRANSPORT],
858858
setConnectErrors,
859859
client_storage,
860860
navigate,

reflex/app.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -529,22 +529,40 @@ def _setup_state(self) -> None:
529529

530530
# Set up the Socket.IO AsyncServer.
531531
if not self.sio:
532+
if (
533+
config.transport == "polling"
534+
and (tier := prerequisites.get_user_tier()) != "enterprise"
535+
):
536+
console.error(
537+
"The 'polling' transport is only available for Enterprise users. "
538+
+ (
539+
"Please upgrade your plan to use this feature."
540+
if tier != "anonymous"
541+
else "Please log in with `reflex login` to use this feature."
542+
)
543+
)
544+
raise SystemExit(1)
532545
self.sio = AsyncServer(
533546
async_mode="asgi",
534547
cors_allowed_origins=(
535-
"*"
536-
if config.cors_allowed_origins == ("*",)
537-
else list(config.cors_allowed_origins)
548+
(
549+
"*"
550+
if config.cors_allowed_origins == ("*",)
551+
else list(config.cors_allowed_origins)
552+
)
553+
if config.transport == "websocket"
554+
else []
538555
),
539-
cors_credentials=True,
556+
cors_credentials=config.transport == "websocket",
540557
max_http_buffer_size=environment.REFLEX_SOCKET_MAX_HTTP_BUFFER_SIZE.get(),
541558
ping_interval=environment.REFLEX_SOCKET_INTERVAL.get(),
542559
ping_timeout=environment.REFLEX_SOCKET_TIMEOUT.get(),
543560
json=SimpleNamespace(
544561
dumps=staticmethod(format.json_dumps),
545562
loads=staticmethod(json.loads),
546563
),
547-
transports=["websocket"],
564+
allow_upgrades=False,
565+
transports=[config.transport],
548566
)
549567
elif getattr(self.sio, "async_mode", "") != "asgi":
550568
msg = f"Custom `sio` must use `async_mode='asgi'`, not '{self.sio.async_mode}'."

reflex/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from importlib.util import find_spec
1111
from pathlib import Path
1212
from types import ModuleType
13-
from typing import TYPE_CHECKING, Any, ClassVar
13+
from typing import TYPE_CHECKING, Any, ClassVar, Literal
1414

1515
from reflex import constants
1616
from reflex.constants.base import LogLevel
@@ -254,6 +254,9 @@ class BaseConfig:
254254
# List of fully qualified import paths of plugins to disable in the app (e.g. reflex.plugins.sitemap.SitemapPlugin).
255255
disable_plugins: list[str] = dataclasses.field(default_factory=list)
256256

257+
# The transport method for client-server communication.
258+
transport: Literal["websocket", "polling"] = "websocket"
259+
257260
# Whether to skip plugin checks.
258261
_skip_plugins_checks: bool = dataclasses.field(default=False, repr=False)
259262

reflex/utils/build.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def set_env_json():
2020
str(prerequisites.get_web_dir() / constants.Dirs.ENV_JSON),
2121
{
2222
**{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint},
23+
"TRANSPORT": get_config().transport,
2324
"TEST_MODE": is_in_app_harness(),
2425
},
2526
)

reflex/utils/prerequisites.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ def check_schema_up_to_date():
674674
)
675675

676676

677+
@once
677678
def get_user_tier():
678679
"""Get the current user's tier.
679680

reflex/utils/processes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from rich.progress import Progress
2020

2121
from reflex import constants
22+
from reflex.config import get_config
2223
from reflex.environment import environment
2324
from reflex.utils import console, path_ops, prerequisites
2425
from reflex.utils.registry import get_npm_registry
@@ -42,6 +43,9 @@ def get_num_workers() -> int:
4243
Returns:
4344
The number of backend worker processes.
4445
"""
46+
if get_config().transport == "polling":
47+
return 1
48+
4549
if (redis_client := prerequisites.get_redis_sync()) is None:
4650
return 1
4751

0 commit comments

Comments
 (0)