@@ -7,48 +7,83 @@ jest.unstable_mockModule("node-fetch", () => ({
77} ) ) ;
88
99import { ReplicateResponseError } from "./errors.js" ;
10+ import Model from "./Model.js" ;
1011import Prediction , { PredictionStatus } from "./Prediction.js" ;
1112
1213const { default : ReplicateClient } = await import ( "./ReplicateClient.js" ) ;
1314
1415let client ;
15- let version ;
16+ let model ;
1617
1718beforeEach ( ( ) => {
1819 process . env . REPLICATE_API_TOKEN = "test-token-from-env" ;
1920
2021 client = new ReplicateClient ( { } ) ;
21- version = client . version ( "test-version" ) ;
22+ model = client . model ( "test-owner/test-name@testversion" ) ;
23+ } ) ;
24+
25+ describe ( "load()" , ( ) => {
26+ it ( "makes request to get model version" , async ( ) => {
27+ jest . spyOn ( client , "request" ) . mockResolvedValue ( {
28+ id : "testversion" ,
29+ } ) ;
30+
31+ await model . load ( ) ;
32+
33+ expect ( client . request ) . toHaveBeenCalledWith (
34+ "GET /v1/models/test-owner/test-name/versions/testversion"
35+ ) ;
36+ } ) ;
37+
38+ it ( "returns Model" , async ( ) => {
39+ jest . spyOn ( client , "request" ) . mockResolvedValue ( {
40+ id : "testversion" ,
41+ } ) ;
42+
43+ const returnedModel = await model . load ( ) ;
44+
45+ expect ( returnedModel ) . toBeInstanceOf ( Model ) ;
46+ } ) ;
47+
48+ it ( "updates Model in place" , async ( ) => {
49+ jest . spyOn ( client , "request" ) . mockResolvedValue ( {
50+ id : "testversion" ,
51+ } ) ;
52+
53+ const returnedModel = await model . load ( ) ;
54+
55+ expect ( returnedModel ) . toBe ( model ) ;
56+ } ) ;
2257} ) ;
2358
2459describe ( "predict()" , ( ) => {
2560 it ( "makes request to create prediction" , async ( ) => {
26- jest . spyOn ( version , "createPrediction" ) . mockResolvedValue (
61+ jest . spyOn ( model , "createPrediction" ) . mockResolvedValue (
2762 new Prediction (
2863 {
29- id : "test-prediction " ,
64+ id : "testprediction " ,
3065 status : PredictionStatus . SUCCEEDED ,
3166 } ,
3267 client
3368 )
3469 ) ;
3570
36- await version . predict (
71+ await model . predict (
3772 { text : "test text" } ,
3873 { } ,
3974 { defaultPollingInterval : 0 }
4075 ) ;
4176
42- expect ( version . createPrediction ) . toHaveBeenCalledWith ( {
77+ expect ( model . createPrediction ) . toHaveBeenCalledWith ( {
4378 text : "test text" ,
4479 } ) ;
4580 } ) ;
4681
4782 it ( "uses created prediction's ID to fetch update" , async ( ) => {
48- jest . spyOn ( version , "createPrediction" ) . mockResolvedValue (
83+ jest . spyOn ( model , "createPrediction" ) . mockResolvedValue (
4984 new Prediction (
5085 {
51- id : "test-prediction " ,
86+ id : "testprediction " ,
5287 status : PredictionStatus . STARTING ,
5388 } ,
5489 client
@@ -75,20 +110,20 @@ describe("predict()", () => {
75110 . spyOn ( client , "request" )
76111 . mockImplementation ( ( action ) => requestMockReturnValues [ action ] ) ;
77112
78- await version . predict (
113+ await model . predict (
79114 { text : "test text" } ,
80115 { } ,
81116 { defaultPollingInterval : 0 }
82117 ) ;
83118
84- expect ( client . prediction ) . toHaveBeenCalledWith ( "test-prediction " ) ;
119+ expect ( client . prediction ) . toHaveBeenCalledWith ( "testprediction " ) ;
85120 } ) ;
86121
87122 it ( "polls prediction status until success" , async ( ) => {
88- jest . spyOn ( version , "createPrediction" ) . mockResolvedValue (
123+ jest . spyOn ( model , "createPrediction" ) . mockResolvedValue (
89124 new Prediction (
90125 {
91- id : "test-prediction " ,
126+ id : "testprediction " ,
92127 status : PredictionStatus . STARTING ,
93128 } ,
94129 client
@@ -98,21 +133,21 @@ describe("predict()", () => {
98133 const predictionLoadResults = [
99134 new Prediction (
100135 {
101- id : "test-prediction " ,
136+ id : "testprediction " ,
102137 status : PredictionStatus . PROCESSING ,
103138 } ,
104139 client
105140 ) ,
106141 new Prediction (
107142 {
108- id : "test-prediction " ,
143+ id : "testprediction " ,
109144 status : PredictionStatus . PROCESSING ,
110145 } ,
111146 client
112147 ) ,
113148 new Prediction (
114149 {
115- id : "test-prediction " ,
150+ id : "testprediction " ,
116151 status : PredictionStatus . SUCCEEDED ,
117152 } ,
118153 client
@@ -122,14 +157,14 @@ describe("predict()", () => {
122157 const predictionLoad = jest . fn ( ( ) => predictionLoadResults . shift ( ) ) ;
123158
124159 jest . spyOn ( client , "prediction" ) . mockImplementation ( ( ) => {
125- const prediction = new Prediction ( { id : "test-prediction " } , client ) ;
160+ const prediction = new Prediction ( { id : "testprediction " } , client ) ;
126161
127162 jest . spyOn ( prediction , "load" ) . mockImplementation ( predictionLoad ) ;
128163
129164 return prediction ;
130165 } ) ;
131166
132- const prediction = await version . predict (
167+ const prediction = await model . predict (
133168 { text : "test text" } ,
134169 { } ,
135170 { defaultPollingInterval : 0 }
@@ -140,10 +175,10 @@ describe("predict()", () => {
140175 } ) ;
141176
142177 it ( "retries polling on error" , async ( ) => {
143- jest . spyOn ( version , "createPrediction" ) . mockResolvedValue (
178+ jest . spyOn ( model , "createPrediction" ) . mockResolvedValue (
144179 new Prediction (
145180 {
146- id : "test-prediction " ,
181+ id : "testprediction " ,
147182 status : PredictionStatus . STARTING ,
148183 } ,
149184 client
@@ -172,7 +207,7 @@ describe("predict()", () => {
172207 ( ) =>
173208 new Prediction (
174209 {
175- id : "test-prediction " ,
210+ id : "testprediction " ,
176211 status : PredictionStatus . SUCCEEDED ,
177212 } ,
178213 client
@@ -182,15 +217,15 @@ describe("predict()", () => {
182217 const predictionLoad = jest . fn ( ( ) => predictionLoadResults . shift ( ) ( ) ) ;
183218
184219 jest . spyOn ( client , "prediction" ) . mockImplementation ( ( ) => {
185- const prediction = new Prediction ( { id : "test-prediction " } , client ) ;
220+ const prediction = new Prediction ( { id : "testprediction " } , client ) ;
186221
187222 jest . spyOn ( prediction , "load" ) . mockImplementation ( predictionLoad ) ;
188223
189224 return prediction ;
190225 } ) ;
191226 const backoffFn = jest . fn ( ( ) => 0 ) ;
192227
193- const prediction = await version . predict (
228+ const prediction = await model . predict (
194229 { text : "test text" } ,
195230 { } ,
196231 { defaultPollingInterval : 0 , backoffFn }
@@ -205,14 +240,14 @@ describe("predict()", () => {
205240describe ( "createPrediction()" , ( ) => {
206241 it ( "makes request to create prediction" , async ( ) => {
207242 jest . spyOn ( client , "request" ) . mockResolvedValue ( {
208- id : "test-prediction " ,
243+ id : "testprediction " ,
209244 status : PredictionStatus . SUCCEEDED ,
210245 } ) ;
211246
212- await version . createPrediction ( { text : "test text" } ) ;
247+ await model . createPrediction ( { text : "test text" } ) ;
213248
214249 expect ( client . request ) . toHaveBeenCalledWith ( "POST /v1/predictions" , {
215- version : "test-version " ,
250+ version : "testversion " ,
216251 input : { text : "test text" } ,
217252 } ) ;
218253 } ) ;
0 commit comments