Skip to content

Commit 0dc073d

Browse files
committed
Split REFLEX_CORS_ALLOWED_ORIGINS by comma
Default delimiter is colon with no stripping of whitespace. However for CORS it's much more convenient to split on a comma. Extend the env_var interpretation system to pull a new SequenceOptions object out of an Annotated type hint. Fix #6066
1 parent 5b3aa27 commit 0dc073d

File tree

4 files changed

+75
-4
lines changed

4 files changed

+75
-4
lines changed

reflex/config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
from importlib.util import find_spec
1111
from pathlib import Path
1212
from types import ModuleType
13-
from typing import TYPE_CHECKING, Any, ClassVar, Literal
13+
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
1414

1515
from reflex import constants
1616
from reflex.constants.base import LogLevel
1717
from reflex.environment import EnvironmentVariables as EnvironmentVariables
1818
from reflex.environment import EnvVar as EnvVar
1919
from reflex.environment import (
2020
ExistingPath,
21+
SequenceOptions,
2122
_load_dotenv_from_files,
2223
_paths_from_env_files,
2324
interpret_env_var_value,
@@ -207,8 +208,10 @@ class BaseConfig:
207208
# Timeout to do a production build of a frontend page.
208209
static_page_generation_timeout: int = 60
209210

210-
# List of origins that are allowed to connect to the backend API.
211-
cors_allowed_origins: Sequence[str] = dataclasses.field(default=("*",))
211+
# Comma separated list of origins that are allowed to connect to the backend API.
212+
cors_allowed_origins: Annotated[Sequence[str], SequenceOptions(delimiter=",")] = (
213+
dataclasses.field(default=("*",))
214+
)
212215

213216
# Whether to use React strict mode.
214217
react_strict_mode: bool = True

reflex/environment.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,17 @@ def interpret_enum_env(value: str, field_type: GenericType, field_name: str) ->
212212
raise EnvironmentVarValueError(msg) from ve
213213

214214

215+
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
216+
class SequenceOptions:
217+
"""Options for interpreting Sequence environment variables."""
218+
219+
delimiter: str = ":"
220+
strip: bool = False
221+
222+
223+
DEFAULT_SEQUENCE_OPTIONS = SequenceOptions()
224+
225+
215226
def interpret_env_var_value(
216227
value: str, field_type: GenericType, field_name: str
217228
) -> Any:
@@ -278,14 +289,26 @@ def interpret_env_var_value(
278289
continue
279290
msg = f"Invalid literal value: {value!r} for {field_name}, expected one of {literal_values}"
280291
raise EnvironmentVarValueError(msg)
292+
# If the field is Annotated with SequenceOptions, extract the options
293+
sequence_options = DEFAULT_SEQUENCE_OPTIONS
294+
if get_origin(field_type) is Annotated:
295+
annotated_args = get_args(field_type)
296+
field_type = annotated_args[0]
297+
for arg in annotated_args[1:]:
298+
if isinstance(arg, SequenceOptions):
299+
sequence_options = arg
300+
break
281301
if get_origin(field_type) in (list, Sequence):
302+
items = value.split(sequence_options.delimiter)
303+
if sequence_options.strip:
304+
items = [item.strip() for item in items]
282305
return [
283306
interpret_env_var_value(
284307
v,
285308
get_args(field_type)[0],
286309
f"{field_name}[{i}]",
287310
)
288-
for i, v in enumerate(value.split(":"))
311+
for i, v in enumerate(items)
289312
]
290313
if isinstance(field_type, type) and issubclass(field_type, enum.Enum):
291314
return interpret_enum_env(value, field_type, field_name)

tests/units/test_config.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,41 @@ def test_update_from_env_path(
9696
assert config.bun_path == tmp_path
9797

9898

99+
def test_update_from_env_cors(
100+
base_config_values: dict[str, Any],
101+
monkeypatch: pytest.MonkeyPatch,
102+
tmp_path: Path,
103+
):
104+
"""Test that environment variables override config values.
105+
106+
Args:
107+
base_config_values: Config values.
108+
monkeypatch: The pytest monkeypatch object.
109+
tmp_path: The pytest tmp_path fixture object.
110+
"""
111+
config = rx.Config(**base_config_values)
112+
assert config.cors_allowed_origins == ("*",)
113+
114+
monkeypatch.setenv("REFLEX_CORS_ALLOWED_ORIGINS", "")
115+
config = rx.Config(**base_config_values)
116+
assert config.cors_allowed_origins == ("*",)
117+
118+
monkeypatch.setenv("REFLEX_CORS_ALLOWED_ORIGINS", "https://foo.example.com")
119+
config = rx.Config(**base_config_values)
120+
assert config.cors_allowed_origins == [
121+
"https://foo.example.com",
122+
]
123+
124+
monkeypatch.setenv(
125+
"REFLEX_CORS_ALLOWED_ORIGINS", "http://example.com, http://another.com "
126+
)
127+
config = rx.Config(**base_config_values)
128+
assert config.cors_allowed_origins == [
129+
"http://example.com",
130+
"http://another.com",
131+
]
132+
133+
99134
@pytest.mark.parametrize(
100135
("kwargs", "expected"),
101136
[

tests/units/test_environment.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import tempfile
66
from pathlib import Path
7+
from typing import Annotated
78
from unittest.mock import patch
89

910
import pytest
@@ -15,6 +16,7 @@
1516
ExecutorType,
1617
ExistingPath,
1718
PerformanceMode,
19+
SequenceOptions,
1820
_load_dotenv_from_files,
1921
_paths_from_env_files,
2022
_paths_from_environment,
@@ -175,6 +177,14 @@ def test_interpret_list(self):
175177
result = interpret_env_var_value("1:2:3", list[int], "TEST_FIELD")
176178
assert result == [1, 2, 3]
177179

180+
def test_interpret_annotated_sequence(self):
181+
"""Test annotated sequence interpretation."""
182+
annotated_type = Annotated[
183+
list[str], SequenceOptions(delimiter=",", strip=True)
184+
]
185+
result = interpret_env_var_value("a, b, c ", annotated_type, "TEST_FIELD")
186+
assert result == ["a", "b", "c"]
187+
178188
def test_interpret_enum(self):
179189
"""Test enum interpretation."""
180190
result = interpret_env_var_value("value1", _TestEnum, "TEST_FIELD")

0 commit comments

Comments
 (0)