Skip to content

Commit 924ab59

Browse files
committed
fix: add type annotations to test_stream.py for linter
1 parent 58ac19f commit 924ab59

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

tests/lib/test_stream.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
from __future__ import annotations
22

33
import os
4-
from typing import Iterator
4+
from typing import Any, Iterator
55

66
import httpx
77
import pytest
8+
from respx import MockRouter
89

910
from replicate import Replicate, AsyncReplicate
1011

1112
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
1213
bearer_token = "My Bearer Token"
1314

1415

15-
def create_mock_prediction_json(stream_url: str | None = None) -> dict:
16+
def create_mock_prediction_json(stream_url: str | None = None) -> dict[str, Any]:
1617
"""Helper to create a complete prediction JSON response"""
1718
prediction = {
1819
"id": "test-prediction-id",
@@ -34,7 +35,7 @@ def create_mock_prediction_json(stream_url: str | None = None) -> dict:
3435
return prediction
3536

3637

37-
def test_stream_with_model_owner_name(respx_mock) -> None:
38+
def test_stream_with_model_owner_name(respx_mock: MockRouter) -> None:
3839
"""Test streaming with owner/name format"""
3940
client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
4041

@@ -71,7 +72,7 @@ def stream_content() -> Iterator[bytes]:
7172
assert output == ["Hello", " world", "!"]
7273

7374

74-
def test_stream_with_version_id(respx_mock) -> None:
75+
def test_stream_with_version_id(respx_mock: MockRouter) -> None:
7576
"""Test streaming with version ID"""
7677
client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
7778
version_id = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
@@ -108,7 +109,7 @@ def stream_content() -> Iterator[bytes]:
108109
assert output == ["Test", "output"]
109110

110111

111-
def test_stream_no_stream_url_raises_error(respx_mock) -> None:
112+
def test_stream_no_stream_url_raises_error(respx_mock: MockRouter) -> None:
112113
"""Test that streaming raises an error when model doesn't support streaming"""
113114
client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
114115

@@ -127,7 +128,7 @@ def test_stream_no_stream_url_raises_error(respx_mock) -> None:
127128

128129

129130
@pytest.mark.asyncio
130-
async def test_async_stream_with_model_owner_name(respx_mock) -> None:
131+
async def test_async_stream_with_model_owner_name(respx_mock: MockRouter) -> None:
131132
"""Test async streaming with owner/name format"""
132133
async_client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
133134

@@ -164,7 +165,7 @@ async def stream_content():
164165

165166

166167
@pytest.mark.asyncio
167-
async def test_async_stream_no_stream_url_raises_error(respx_mock) -> None:
168+
async def test_async_stream_no_stream_url_raises_error(respx_mock: MockRouter) -> None:
168169
"""Test that async streaming raises an error when model doesn't support streaming"""
169170
async_client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
170171

@@ -182,7 +183,7 @@ async def test_async_stream_no_stream_url_raises_error(respx_mock) -> None:
182183
pass
183184

184185

185-
def test_stream_module_level(respx_mock) -> None:
186+
def test_stream_module_level(respx_mock: MockRouter) -> None:
186187
"""Test that module-level stream function works"""
187188
import replicate
188189

0 commit comments

Comments
 (0)