Skip to content

Commit 2796d1c

Browse files
vincentsaragolukasbindreiter
authored andcommitted
use regex and simplify
1 parent 8983736 commit 2796d1c

File tree

2 files changed

+35
-26
lines changed

2 files changed

+35
-26
lines changed

stac_fastapi/api/stac_fastapi/api/middleware.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6868
proto == "https" and port != HTTPS_PORT
6969
):
7070
port_suffix = f":{port}"
71+
7172
scope["headers"] = self._replace_header_value_by_name(
7273
scope,
7374
"host",
7475
f"{domain}{port_suffix}",
7576
)
77+
7678
await self.app(scope, receive, send)
7779

78-
def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]: # noqa: C901
80+
def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]:
7981
proto = scope.get("scheme", "http")
8082
header_host = self._get_header_value_by_name(scope, "host")
8183
if header_host is None:
@@ -87,35 +89,28 @@ def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]: # noqa: C901
8789
else:
8890
domain = header_host_parts[0]
8991
port = None
90-
forwarded = self._get_header_value_by_name(scope, "forwarded")
91-
if forwarded is not None:
92-
proxy_servers = forwarded.split(",") # values from the last server are used
93-
for proxy_server in proxy_servers:
94-
parts = proxy_server.split(";")
95-
for part in parts:
96-
if len(part) > 0 and re.search("=", part):
97-
key, value = part.split("=")
98-
if key == "proto":
99-
proto = value
100-
elif key == "host":
101-
host_parts = value.split(":")
102-
domain = host_parts[0]
103-
try:
104-
port = (
105-
int(host_parts[1]) if len(host_parts) == 2 else None
106-
)
107-
except ValueError:
108-
# ignore ports that are not valid integers
109-
pass
92+
93+
if forwarded := self._get_header_value_by_name(scope, "forwarded"):
94+
for proxy in forwarded.split(","):
95+
if (proto_expr := re.search(r"proto=(?P<proto>http(s)?)", proxy)) and (
96+
host_expr := re.search(
97+
r"host=(?P<host>[\w.-]+)(:(?P<port>\w+))?", proxy
98+
)
99+
):
100+
proto = proto_expr.groupdict()["proto"]
101+
domain = host_expr.groupdict()["host"]
102+
port_str = host_expr.groupdict().get("port", None)
103+
110104
else:
111105
domain = self._get_header_value_by_name(scope, "x-forwarded-host", domain)
112106
proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto)
113107
port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port)
114-
try:
115-
port = int(port_str) if port_str is not None else None
116-
except ValueError:
117-
# ignore ports that are not valid integers
118-
pass
108+
109+
try:
110+
port = int(port_str) if port_str is not None else None
111+
except ValueError:
112+
# ignore ports that are not valid integers
113+
pass
119114

120115
return (proto, domain, port)
121116

stac_fastapi/api/tests/test_middleware.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,20 @@ def test_replace_header_value_by_name(
169169
},
170170
("https", "second-server", 1111),
171171
),
172+
(
173+
{
174+
"scheme": "http",
175+
"server": ["testserver", 80],
176+
"headers": [
177+
(
178+
b"forwarded",
179+
# check when host and port are inverted
180+
b"host=test:1234;proto=https",
181+
)
182+
],
183+
},
184+
("https", "test", 1234),
185+
),
172186
],
173187
)
174188
def test_get_forwarded_url_parts(

0 commit comments

Comments
 (0)