Skip to content

Commit 0cff3dc

Browse files
Create new URI when replacing
1 parent aeb12f2 commit 0cff3dc

File tree

2 files changed

+159
-11
lines changed

2 files changed

+159
-11
lines changed

packages/smithy-http/src/smithy_http/aio/protocols.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from inspect import iscoroutinefunction
66
from typing import Any
77

8+
from smithy_core import URI as _URI
89
from smithy_core.aio.interfaces import AsyncByteStream, ClientProtocol
910
from smithy_core.aio.interfaces import StreamingBlob as AsyncStreamingBlob
1011
from smithy_core.codecs import Codec
@@ -39,17 +40,27 @@ def set_service_endpoint(
3940
endpoint: Endpoint,
4041
) -> HTTPRequest:
4142
uri = endpoint.uri
42-
uri_builder = request.destination
43-
44-
if uri.scheme:
45-
uri_builder.scheme = uri.scheme
46-
if uri.host:
47-
uri_builder.host = uri.host
48-
if uri.port and uri.port > -1:
49-
uri_builder.port = uri.port
50-
if uri.path:
51-
uri_builder.path = os.path.join(uri.path, uri_builder.path or "")
52-
# TODO: merge headers from the endpoint properties bag
43+
previous = request.destination
44+
45+
path = previous.path or uri.path
46+
if uri.path is not None and previous.path is not None:
47+
path = os.path.join(uri.path, previous.path.lstrip("/"))
48+
49+
query = previous.query or uri.query
50+
if uri.query is not None and previous.query is not None:
51+
query = f"{uri.query}&{previous.query}"
52+
53+
request.destination = _URI(
54+
scheme=uri.scheme,
55+
username=uri.username or previous.username,
56+
password=uri.password or previous.password,
57+
host=uri.host,
58+
port=uri.port or previous.port,
59+
path=path,
60+
query=query,
61+
fragment=uri.fragment or previous.fragment,
62+
)
63+
5364
return request
5465

5566

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from typing import Any
5+
6+
import pytest
7+
from smithy_core import URI
8+
from smithy_core.documents import TypeRegistry
9+
from smithy_core.endpoints import Endpoint
10+
from smithy_core.interfaces import TypedProperties
11+
from smithy_core.interfaces import URI as URIInterface
12+
from smithy_core.schemas import APIOperation
13+
from smithy_core.shapes import ShapeID
14+
from smithy_http import Fields
15+
from smithy_http.aio import HTTPRequest
16+
from smithy_http.aio.interfaces import HTTPRequest as HTTPRequestInterface
17+
from smithy_http.aio.interfaces import HTTPResponse as HTTPResponseInterface
18+
from smithy_http.aio.protocols import HttpClientProtocol
19+
20+
21+
class TestProtocol(HttpClientProtocol):
22+
_id = ShapeID("ns.foo#bar")
23+
24+
@property
25+
def id(self) -> ShapeID:
26+
return self._id
27+
28+
def serialize_request(
29+
self,
30+
*,
31+
operation: APIOperation[Any, Any],
32+
input: Any,
33+
endpoint: URIInterface,
34+
context: TypedProperties,
35+
) -> HTTPRequestInterface:
36+
raise Exception("This is only for tests.")
37+
38+
def deserialize_response(
39+
self,
40+
*,
41+
operation: APIOperation[Any, Any],
42+
request: HTTPRequestInterface,
43+
response: HTTPResponseInterface,
44+
error_registry: TypeRegistry,
45+
context: TypedProperties,
46+
) -> Any:
47+
raise Exception("This is only for tests.")
48+
49+
50+
@pytest.mark.parametrize(
51+
"request_uri,endpoint_uri,expected",
52+
[
53+
(
54+
URI(host="com.example", path="/foo"),
55+
URI(host="com.example", path="/bar"),
56+
URI(host="com.example", path="/bar/foo"),
57+
),
58+
(
59+
URI(host="com.example"),
60+
URI(host="com.example", path="/bar"),
61+
URI(host="com.example", path="/bar"),
62+
),
63+
(
64+
URI(host="com.example", path="/foo"),
65+
URI(host="com.example"),
66+
URI(host="com.example", path="/foo"),
67+
),
68+
(
69+
URI(host="com.example", scheme="http"),
70+
URI(host="com.example", scheme="https"),
71+
URI(host="com.example", scheme="https"),
72+
),
73+
(
74+
URI(host="com.example", username="name", password="password"),
75+
URI(host="com.example", username="othername", password="otherpassword"),
76+
URI(host="com.example", username="othername", password="otherpassword"),
77+
),
78+
(
79+
URI(host="com.example", username="name", password="password"),
80+
URI(host="com.example"),
81+
URI(host="com.example", username="name", password="password"),
82+
),
83+
(
84+
URI(host="com.example", port=8080),
85+
URI(host="com.example", port=8000),
86+
URI(host="com.example", port=8000),
87+
),
88+
(
89+
URI(host="com.example", port=8080),
90+
URI(host="com.example"),
91+
URI(host="com.example", port=8080),
92+
),
93+
(
94+
URI(host="com.example", query="foo=bar"),
95+
URI(host="com.example"),
96+
URI(host="com.example", query="foo=bar"),
97+
),
98+
(
99+
URI(host="com.example"),
100+
URI(host="com.example", query="spam"),
101+
URI(host="com.example", query="spam"),
102+
),
103+
(
104+
URI(host="com.example", query="foo=bar"),
105+
URI(host="com.example", query="spam"),
106+
URI(host="com.example", query="spam&foo=bar"),
107+
),
108+
(
109+
URI(host="com.example", fragment="header"),
110+
URI(host="com.example", fragment="footer"),
111+
URI(host="com.example", fragment="footer"),
112+
),
113+
(
114+
URI(host="com.example"),
115+
URI(host="com.example", fragment="footer"),
116+
URI(host="com.example", fragment="footer"),
117+
),
118+
(
119+
URI(host="com.example", fragment="header"),
120+
URI(host="com.example"),
121+
URI(host="com.example", fragment="header"),
122+
),
123+
],
124+
)
125+
def test_http_protocol_joins_uris(
126+
request_uri: URI, endpoint_uri: URI, expected: URI
127+
) -> None:
128+
protocol = TestProtocol()
129+
request = HTTPRequest(
130+
destination=request_uri,
131+
method="GET",
132+
fields=Fields(),
133+
)
134+
endpoint = Endpoint(uri=endpoint_uri)
135+
updated_request = protocol.set_service_endpoint(request=request, endpoint=endpoint)
136+
actual = updated_request.destination
137+
assert actual == expected

0 commit comments

Comments
 (0)