|
5 | 5 |
|
6 | 6 | from typing_extensions import NotRequired, TypedDict, Unpack |
7 | 7 |
|
8 | | -from replicate.exceptions import ModelError |
| 8 | +from replicate.exceptions import ModelError, ReplicateError |
9 | 9 | from replicate.files import upload_file |
10 | 10 | from replicate.json import encode_json |
11 | 11 | from replicate.pagination import Page |
12 | 12 | from replicate.resource import Namespace, Resource |
| 13 | +from replicate.stream import EventSource |
13 | 14 | from replicate.version import Version |
14 | 15 |
|
15 | 16 | try: |
|
19 | 20 |
|
20 | 21 | if TYPE_CHECKING: |
21 | 22 | from replicate.client import Client |
| 23 | + from replicate.stream import ServerSentEvent |
22 | 24 |
|
23 | 25 |
|
24 | 26 | class Prediction(Resource): |
@@ -125,6 +127,25 @@ def wait(self) -> None: |
125 | 127 | time.sleep(self._client.poll_interval) |
126 | 128 | self.reload() |
127 | 129 |
|
| 130 | + def stream(self) -> Optional[Iterator["ServerSentEvent"]]: |
| 131 | + """ |
| 132 | + Stream the prediction output. |
| 133 | +
|
| 134 | + Raises: |
| 135 | + ReplicateError: If the model does not support streaming. |
| 136 | + """ |
| 137 | + |
| 138 | + url = self.urls and self.urls.get("stream", None) |
| 139 | + if not url or not isinstance(url, str): |
| 140 | + raise ReplicateError("Model does not support streaming") |
| 141 | + |
| 142 | + headers = {} |
| 143 | + headers["Accept"] = "text/event-stream" |
| 144 | + headers["Cache-Control"] = "no-store" |
| 145 | + |
| 146 | + with self._client._client.stream("GET", url, headers=headers) as response: |
| 147 | + yield from EventSource(response) |
| 148 | + |
128 | 149 | def cancel(self) -> None: |
129 | 150 | """ |
130 | 151 | Cancels a running prediction. |
|
0 commit comments