|
7 | 7 |
|
8 | 8 | import replicate |
9 | 9 | from replicate.client import Client |
10 | | -from replicate.exceptions import ReplicateError |
| 10 | +from replicate.exceptions import ModelError, ReplicateError |
11 | 11 |
|
12 | 12 |
|
13 | 13 | @pytest.mark.vcr("run.yaml") |
@@ -184,3 +184,72 @@ def prediction_with_status(status: str) -> dict: |
184 | 184 | ) |
185 | 185 |
|
186 | 186 | assert output == "Hello, world!" |
| 187 | + |
| 188 | + |
| 189 | +@pytest.mark.asyncio |
| 190 | +async def test_run_with_model_error(mock_replicate_api_token): |
| 191 | + def prediction_with_status(status: str) -> dict: |
| 192 | + return { |
| 193 | + "id": "p1", |
| 194 | + "model": "test/example", |
| 195 | + "version": "v1", |
| 196 | + "urls": { |
| 197 | + "get": "https://api.replicate.com/v1/predictions/p1", |
| 198 | + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", |
| 199 | + }, |
| 200 | + "created_at": "2023-10-05T12:00:00.000000Z", |
| 201 | + "source": "api", |
| 202 | + "status": status, |
| 203 | + "input": {"text": "world"}, |
| 204 | + "output": None, |
| 205 | + "error": "OOM" if status == "failed" else None, |
| 206 | + "logs": "", |
| 207 | + } |
| 208 | + |
| 209 | + router = respx.Router(base_url="https://api.replicate.com/v1") |
| 210 | + router.route(method="POST", path="/predictions").mock( |
| 211 | + return_value=httpx.Response( |
| 212 | + 201, |
| 213 | + json=prediction_with_status("processing"), |
| 214 | + ) |
| 215 | + ) |
| 216 | + router.route(method="GET", path="/predictions/p1").mock( |
| 217 | + return_value=httpx.Response( |
| 218 | + 200, |
| 219 | + json=prediction_with_status("failed"), |
| 220 | + ) |
| 221 | + ) |
| 222 | + router.route( |
| 223 | + method="GET", |
| 224 | + path="/models/test/example/versions/v1", |
| 225 | + ).mock( |
| 226 | + return_value=httpx.Response( |
| 227 | + 201, |
| 228 | + json={ |
| 229 | + "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", |
| 230 | + "created_at": "2024-07-18T00:35:56.210272Z", |
| 231 | + "cog_version": "0.9.10", |
| 232 | + "openapi_schema": { |
| 233 | + "openapi": "3.0.2", |
| 234 | + }, |
| 235 | + }, |
| 236 | + ) |
| 237 | + ) |
| 238 | + router.route(host="api.replicate.com").pass_through() |
| 239 | + |
| 240 | + client = Client( |
| 241 | + api_token="test-token", transport=httpx.MockTransport(router.handler) |
| 242 | + ) |
| 243 | + client.poll_interval = 0.001 |
| 244 | + |
| 245 | + with pytest.raises(ModelError) as excinfo: |
| 246 | + client.run( |
| 247 | + "test/example:v1", |
| 248 | + input={ |
| 249 | + "text": "Hello, world!", |
| 250 | + }, |
| 251 | + ) |
| 252 | + |
| 253 | + assert str(excinfo.value) == "OOM" |
| 254 | + assert excinfo.value.prediction.error == "OOM" |
| 255 | + assert excinfo.value.prediction.status == "failed" |
0 commit comments