@@ -140,7 +140,7 @@ def test__init__with_default_region(session_cls, mock_client_method, mock_bedroc
140
140
with unittest .mock .patch .object (os , "environ" , {}):
141
141
BedrockModel ()
142
142
session_cls .return_value .client .assert_called_with (
143
- region_name = DEFAULT_BEDROCK_REGION , config = ANY , service_name = ANY
143
+ region_name = DEFAULT_BEDROCK_REGION , config = ANY , service_name = ANY , endpoint_url = None
144
144
)
145
145
146
146
@@ -150,22 +150,22 @@ def test__init__with_session_region(session_cls, mock_client_method, mock_bedroc
150
150
151
151
BedrockModel ()
152
152
153
- mock_client_method .assert_called_with (region_name = "eu-blah-1" , config = ANY , service_name = ANY )
153
+ mock_client_method .assert_called_with (region_name = "eu-blah-1" , config = ANY , service_name = ANY , endpoint_url = None )
154
154
155
155
156
156
def test__init__with_custom_region (mock_client_method , mock_bedrock_inference_profiles ):
157
157
"""Test that BedrockModel uses the provided region."""
158
158
custom_region = "us-east-1"
159
159
BedrockModel (region_name = custom_region )
160
- mock_client_method .assert_called_with (region_name = custom_region , config = ANY , service_name = ANY )
160
+ mock_client_method .assert_called_with (region_name = custom_region , config = ANY , service_name = ANY , endpoint_url = None )
161
161
162
162
163
163
def test__init__with_default_environment_variable_region (mock_client_method , mock_bedrock_inference_profiles ):
164
164
"""Test that BedrockModel uses the AWS_REGION since we code that in."""
165
165
with unittest .mock .patch .object (os , "environ" , {"AWS_REGION" : "eu-west-2" }):
166
166
BedrockModel ()
167
167
168
- mock_client_method .assert_called_with (region_name = "eu-west-2" , config = ANY , service_name = ANY )
168
+ mock_client_method .assert_called_with (region_name = "eu-west-2" , config = ANY , service_name = ANY , endpoint_url = None )
169
169
170
170
171
171
def test__init__region_precedence (mock_client_method , session_cls , mock_bedrock_inference_profiles ):
@@ -175,21 +175,38 @@ def test__init__region_precedence(mock_client_method, session_cls, mock_bedrock_
175
175
176
176
# specifying a region always wins out
177
177
BedrockModel (region_name = "us-specified-1" )
178
- mock_client_method .assert_called_with (region_name = "us-specified-1" , config = ANY , service_name = ANY )
178
+ mock_client_method .assert_called_with (
179
+ region_name = "us-specified-1" , config = ANY , service_name = ANY , endpoint_url = None
180
+ )
179
181
180
182
# other-wise uses the session's
181
183
BedrockModel ()
182
- mock_client_method .assert_called_with (region_name = "us-session-1" , config = ANY , service_name = ANY )
184
+ mock_client_method .assert_called_with (
185
+ region_name = "us-session-1" , config = ANY , service_name = ANY , endpoint_url = None
186
+ )
183
187
184
188
# environment variable next
185
189
session_cls .return_value .region_name = None
186
190
BedrockModel ()
187
- mock_client_method .assert_called_with (region_name = "us-environment-1" , config = ANY , service_name = ANY )
191
+ mock_client_method .assert_called_with (
192
+ region_name = "us-environment-1" , config = ANY , service_name = ANY , endpoint_url = None
193
+ )
188
194
189
195
mock_os_environ .pop ("AWS_REGION" )
190
196
session_cls .return_value .region_name = None # No session region
191
197
BedrockModel ()
192
- mock_client_method .assert_called_with (region_name = DEFAULT_BEDROCK_REGION , config = ANY , service_name = ANY )
198
+ mock_client_method .assert_called_with (
199
+ region_name = DEFAULT_BEDROCK_REGION , config = ANY , service_name = ANY , endpoint_url = None
200
+ )
201
+
202
+
203
+ def test__init__with_endpoint_url (mock_client_method ):
204
+ """Test that BedrockModel uses the provided endpoint_url for VPC endpoints."""
205
+ custom_endpoint = "https://vpce-12345-abcde.bedrock-runtime.us-west-2.vpce.amazonaws.com"
206
+ BedrockModel (endpoint_url = custom_endpoint )
207
+ mock_client_method .assert_called_with (
208
+ region_name = DEFAULT_BEDROCK_REGION , config = ANY , service_name = ANY , endpoint_url = custom_endpoint
209
+ )
193
210
194
211
195
212
def test__init__with_region_and_session_raises_value_error ():
0 commit comments