diff --git a/CHANGES.md b/CHANGES.md index 1f4f1568..d3a0db6f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Fixed + +- fix links when app is mounted behind proxy or has router-prefix ([#195](https://github.com/stac-utils/stac-fastapi-pgstac/pull/195)) + ## [4.0.2] - 2025-02-18 ### Fixed diff --git a/setup.py b/setup.py index 702c420c..1e5837d5 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,9 @@ "orjson", "pydantic", "stac_pydantic==3.1.*", - "stac-fastapi.api~=5.0", - "stac-fastapi.extensions~=5.0", - "stac-fastapi.types~=5.0", + "stac-fastapi.api>=5.0,<5.1", + "stac-fastapi.extensions>=5.0,<5.1", + "stac-fastapi.types>=5.0,<5.1", "asyncpg", "buildpg", "brotli_asgi", diff --git a/stac_fastapi/pgstac/models/links.py b/stac_fastapi/pgstac/models/links.py index 819e39e5..6697488b 100644 --- a/stac_fastapi/pgstac/models/links.py +++ b/stac_fastapi/pgstac/models/links.py @@ -51,7 +51,11 @@ def base_url(self): @property def url(self): """Get the current request url.""" - return str(self.request.url) + url = urljoin(str(self.request.base_url), self.request.url.path.lstrip("/")) + if qs := self.request.url.query: + url += f"?{qs}" + + return url def resolve(self, url): """Resolve url to the current request url.""" @@ -143,7 +147,7 @@ def link_next(self) -> Optional[Dict[str, Any]]: "rel": Relations.next.value, "type": MimeTypes.geojson.value, "method": method, - "href": f"{self.request.url}", + "href": self.url, "body": {**self.request.postbody, "token": f"next:{self.next}"}, } @@ -167,7 +171,7 @@ def link_prev(self) -> Optional[Dict[str, Any]]: "rel": Relations.previous.value, "type": MimeTypes.geojson.value, "method": method, - "href": f"{self.request.url}", + "href": self.url, "body": {**self.request.postbody, "token": f"prev:{self.prev}"}, } return None diff --git a/tests/api/test_links.py b/tests/api/test_links.py new file mode 100644 index 00000000..e8e57a95 --- /dev/null +++ b/tests/api/test_links.py @@ -0,0 +1,106 @@ +import pytest +from fastapi import APIRouter, FastAPI +from starlette.requests import Request +from starlette.testclient import TestClient + +from stac_fastapi.pgstac.models import links as app_links + + +@pytest.mark.parametrize("root_path", ["", "/api/v1"]) +@pytest.mark.parametrize("prefix", ["", "/stac"]) +def tests_app_links(prefix, root_path): # noqa: C901 + endpoint_prefix = root_path + prefix + url_prefix = "http://stac.io" + endpoint_prefix + + app = FastAPI(root_path=root_path) + router = APIRouter(prefix=prefix) + app.state.router_prefix = router.prefix + + @router.get("/search") + @router.post("/search") + async def search(request: Request): + links = app_links.PagingLinks(request, next="yo:2", prev="yo:1") + return { + "url": links.url, + "base_url": links.base_url, + "links": await links.get_links(), + } + + @router.get("/collections") + async def collections(request: Request): + pgstac_next = { + "rel": "next", + "body": {"offset": 1}, + "href": "./collections", + "type": "application/json", + "merge": True, + "method": "GET", + } + pgstac_prev = { + "rel": "prev", + "body": {"offset": 0}, + "href": "./collections", + "type": "application/json", + "merge": True, + "method": "GET", + } + links = app_links.CollectionSearchPagingLinks( + request, next=pgstac_next, prev=pgstac_prev + ) + return { + "url": links.url, + "base_url": links.base_url, + "links": await links.get_links(), + } + + app.include_router(router) + + with TestClient( + app, + base_url="http://stac.io", + root_path=root_path, + ) as client: + response = client.get(f"{prefix}/search") + assert response.status_code == 200 + assert response.json()["url"] == url_prefix + "/search" + assert response.json()["base_url"].rstrip("/") == url_prefix + links = response.json()["links"] + for link in links: + if link["rel"] in ["previous", "next"]: + assert link["method"] == "GET" + assert link["href"].startswith(url_prefix) + assert {"next", "previous", "root", "self"} == {link["rel"] for link in links} + + response = client.get(f"{prefix}/search", params={"limit": 1}) + assert response.status_code == 200 + assert response.json()["url"] == url_prefix + "/search?limit=1" + assert response.json()["base_url"].rstrip("/") == url_prefix + links = response.json()["links"] + for link in links: + if link["rel"] in ["previous", "next"]: + assert link["method"] == "GET" + assert "limit=1" in link["href"] + assert link["href"].startswith(url_prefix) + assert {"next", "previous", "root", "self"} == {link["rel"] for link in links} + + response = client.post(f"{prefix}/search", json={}) + assert response.status_code == 200 + assert response.json()["url"] == url_prefix + "/search" + assert response.json()["base_url"].rstrip("/") == url_prefix + links = response.json()["links"] + for link in links: + if link["rel"] in ["previous", "next"]: + assert link["method"] == "POST" + assert link["href"].startswith(url_prefix) + assert {"next", "previous", "root", "self"} == {link["rel"] for link in links} + + response = client.get(f"{prefix}/collections") + assert response.status_code == 200 + assert response.json()["url"] == url_prefix + "/collections" + assert response.json()["base_url"].rstrip("/") == url_prefix + links = response.json()["links"] + for link in links: + if link["rel"] in ["previous", "next"]: + assert link["method"] == "GET" + assert link["href"].startswith(url_prefix) + assert {"next", "previous", "root", "self"} == {link["rel"] for link in links} diff --git a/tests/conftest.py b/tests/conftest.py index ce456534..ec411699 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ -import asyncio import json import logging import os @@ -62,11 +61,6 @@ ) -@pytest.fixture(scope="session") -def event_loop(): - return asyncio.get_event_loop() - - @pytest.fixture(scope="session") def database(postgresql_proc): with DatabaseJanitor(