diff --git a/docs/settings.md b/docs/settings.md index e65c840..fafb6b1 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -57,6 +57,7 @@ In version `6.0.0` we've renamed the PG configuration variable to match the offi - `STAC_FASTAPI_LANDING_ID` (string) is a unique identifier for your Landing page - `ROOT_PATH`: set application root-path (when using proxy) - `CORS_ORIGINS`: A list of origins that should be permitted to make cross-origin requests. Defaults to `*` +- `CORS_ORIGIN_REGEX`: A regex string to match against origins that should be permitted to make cross-origin requests. eg. 'https://.*\.example\.org'. - `CORS_METHODS`: A list of HTTP methods that should be allowed for cross-origin requests. Defaults to `"GET,POST,OPTIONS"` - `CORS_CREDENTIALS`: Set to `true` to enable credentials via CORS requests. Note that you'll need to set `CORS_ORIGINS` to something other than `*`, because credentials are [disallowed](https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS/Errors/CORSNotSupportingCredentials) for wildcard CORS origins. - `CORS_HEADERS`: If `CORS_CREDENTIALS` are true and you're using an `Authorization` header, set this to `Content-Type,Authorization`. Alternatively, you can allow all headers by setting this to `*`. diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 2b0d175..0b1e831 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -185,6 +185,7 @@ async def lifespan(app: FastAPI): Middleware( CORSMiddleware, allow_origins=settings.cors_origins, + allow_origin_regex=settings.cors_origin_regex, allow_methods=settings.cors_methods, allow_credentials=settings.cors_credentials, allow_headers=settings.cors_headers, diff --git a/stac_fastapi/pgstac/config.py b/stac_fastapi/pgstac/config.py index 00432af..74fa271 100644 --- a/stac_fastapi/pgstac/config.py +++ b/stac_fastapi/pgstac/config.py @@ -1,12 +1,13 @@ """Postgres API configuration.""" import warnings -from typing import Annotated, Any, List, Optional, Type +from typing import Annotated, Any, List, Optional, Sequence, Type from urllib.parse import quote_plus as quote -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, BeforeValidator, Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from stac_fastapi.types.config import ApiSettings +from typing_extensions import Self from stac_fastapi.pgstac.types.base_item_cache import ( BaseItemCache, @@ -155,6 +156,12 @@ def connection_string(self): return f"postgresql://{self.pguser}:{quote(self.pgpassword)}@{self.pghost}:{self.pgport}/{self.pgdatabase}" +def str_to_list(value: Any) -> Any: + if isinstance(value, str): + return [v.strip() for v in value.split(",")] + return value + + class Settings(ApiSettings): """API settings. @@ -177,24 +184,25 @@ class Settings(ApiSettings): Implies that the `Transactions` extension is enabled. """ - cors_origins: str = "*" - cors_methods: str = "GET,POST,OPTIONS" + cors_origins: Annotated[Sequence[str], BeforeValidator(str_to_list)] = ("*",) + cors_origin_regex: Optional[str] = None + cors_methods: Annotated[Sequence[str], BeforeValidator(str_to_list)] = ( + "GET", + "POST", + "OPTIONS", + ) cors_credentials: bool = False - cors_headers: str = "Content-Type" + cors_headers: Annotated[Sequence[str], BeforeValidator(str_to_list)] = ( + "Content-Type", + ) testing: bool = False - @field_validator("cors_origins") - def parse_cors_origin(cls, v): - """Parse CORS origins.""" - return [origin.strip() for origin in v.split(",")] - - @field_validator("cors_methods") - def parse_cors_methods(cls, v): - """Parse CORS methods.""" - return [method.strip() for method in v.split(",")] + @model_validator(mode="after") + def check_origins(self) -> Self: + if self.cors_origin_regex and "*" in self.cors_origins: + raise ValueError( + "Conflicting options found in API settings: `cors_origin_regex` and `*` in `cors_origins`" + ) - @field_validator("cors_headers") - def parse_cors_headers(cls, v): - """Parse CORS headers.""" - return [header.strip() for header in v.split(",")] + return self