@@ -48,10 +48,16 @@ export default class Model extends ReplicateObject {
4848
4949 async predict (
5050 input ,
51- { onUpdate = noop , onTemporaryError = noop } = { } ,
51+ {
52+ onUpdate = noop ,
53+ onTemporaryError = noop ,
54+ onCancel = noop ,
55+ onCancelError = noop ,
56+ } = { } ,
5257 {
5358 defaultPollingInterval = 500 ,
5459 backoffFn = ( errorCount ) => Math . pow ( 2 , errorCount ) * 100 ,
60+ cancelOnFatalError = false ,
5561 } = { }
5662 ) {
5763 if ( ! input ) {
@@ -60,39 +66,57 @@ export default class Model extends ReplicateObject {
6066
6167 let prediction = await this . createPrediction ( input ) ;
6268
63- onUpdate ( prediction ) ;
69+ try {
70+ onUpdate ( prediction ) ;
6471
65- let pollingInterval = defaultPollingInterval ;
66- let errorCount = 0 ;
72+ let pollingInterval = defaultPollingInterval ;
73+ let errorCount = 0 ;
6774
68- while ( ! prediction . hasTerminalStatus ( ) ) {
69- await sleep ( pollingInterval ) ;
70- pollingInterval = defaultPollingInterval ; // Reset to default each time.
75+ while ( ! prediction . hasTerminalStatus ( ) ) {
76+ await sleep ( pollingInterval ) ;
77+ pollingInterval = defaultPollingInterval ; // Reset to default each time.
7178
72- try {
73- prediction = await this . client . prediction ( prediction . id ) . load ( ) ;
79+ try {
80+ prediction = await this . client . prediction ( prediction . id ) . load ( ) ;
7481
75- onUpdate ( prediction ) ;
82+ onUpdate ( prediction ) ;
7683
77- errorCount = 0 ; // Reset because we've had a non-error response.
78- } catch ( err ) {
79- if ( ! err instanceof ReplicateResponseError ) {
80- throw err ;
81- }
84+ errorCount = 0 ; // Reset because we've had a non-error response.
85+ } catch ( err ) {
86+ if ( ! err instanceof ReplicateResponseError ) {
87+ throw err ;
88+ }
8289
83- if (
84- ! err . status ||
85- ( Math . floor ( err . status / 100 ) !== 5 && err . status !== 429 )
86- ) {
87- throw err ;
88- }
90+ if (
91+ ! err . status ||
92+ ( Math . floor ( err . status / 100 ) !== 5 && err . status !== 429 )
93+ ) {
94+ throw err ;
95+ }
8996
90- errorCount += 1 ;
97+ errorCount += 1 ;
9198
92- onTemporaryError ( err ) ;
99+ onTemporaryError ( err ) ;
93100
94- pollingInterval = backoffFn ( errorCount ) ;
101+ pollingInterval = backoffFn ( errorCount ) ;
102+ }
103+ }
104+ } catch ( err ) {
105+ if ( cancelOnFatalError ) {
106+ // We intentionally don't await this, so we don't block.
107+ prediction
108+ . cancel ( )
109+ . catch ( ( e ) => {
110+ onCancelError ( e ) ;
111+
112+ throw e ;
113+ } )
114+ . then ( ( ) => {
115+ onCancel ( ) ;
116+ } ) ;
95117 }
118+
119+ throw err ;
96120 }
97121
98122 return prediction ;
0 commit comments