33import pytest
44
55import replicate
6+ from replicate .exceptions import ReplicateError
67from replicate .stream import ServerSentEvent
78
89skip_if_no_token = pytest .mark .skipif (
@@ -21,22 +22,28 @@ async def test_stream(async_flag, record_mode):
2122
2223 events = []
2324
24- if async_flag :
25- async for event in await replicate .async_stream (
26- model ,
27- input = input ,
28- ):
29- events .append (event )
30- else :
31- for event in replicate .stream (
32- model ,
33- input = input ,
34- ):
35- events .append (event )
25+ try :
26+ if async_flag :
27+ async for event in await replicate .async_stream (
28+ model ,
29+ input = input ,
30+ ):
31+ events .append (event )
32+ else :
33+ for event in replicate .stream (
34+ model ,
35+ input = input ,
36+ ):
37+ events .append (event )
3638
37- assert len (events ) > 0
38- assert any (event .event == ServerSentEvent .EventType .OUTPUT for event in events )
39- assert any (event .event == ServerSentEvent .EventType .DONE for event in events )
39+ assert len (events ) > 0
40+ assert any (event .event == ServerSentEvent .EventType .OUTPUT for event in events )
41+ assert any (event .event == ServerSentEvent .EventType .DONE for event in events )
42+ except ReplicateError as e :
43+ if e .status == 401 :
44+ pytest .skip ("Skipping test due to authentication error" )
45+ else :
46+ raise e
4047
4148
4249@skip_if_no_token
@@ -50,15 +57,21 @@ async def test_stream_prediction(async_flag, record_mode):
5057
5158 events = []
5259
53- if async_flag :
54- async for event in replicate .predictions .create (
55- version = version , input = input , stream = True
56- ).async_stream ():
57- events .append (event )
58- else :
59- for event in replicate .predictions .create (
60- version = version , input = input , stream = True
61- ).stream ():
62- events .append (event )
60+ try :
61+ if async_flag :
62+ async for event in replicate .predictions .create (
63+ version = version , input = input , stream = True
64+ ).async_stream ():
65+ events .append (event )
66+ else :
67+ for event in replicate .predictions .create (
68+ version = version , input = input , stream = True
69+ ).stream ():
70+ events .append (event )
6371
64- assert len (events ) > 0
72+ assert len (events ) > 0
73+ except ReplicateError as e :
74+ if e .status == 401 :
75+ pytest .skip ("Skipping test due to authentication error" )
76+ else :
77+ raise e
0 commit comments