Skip to content

Commit 195d096

Browse files
committed
Add support for block: or wait: true to run
1 parent 61f60ac commit 195d096

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

index.js

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class Replicate {
133133
* @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
134134
* @param {object} options
135135
* @param {object} options.input - Required. An object with the model inputs
136-
* @param {object} [options.wait] - Options for waiting for the prediction to finish
136+
* @param {object} [options.wait] - Options for waiting for the prediction to finish. If `wait` is explicitly true, the function will block and wait for the prediction to finish.
137137
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500
138138
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
139139
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
@@ -144,20 +144,26 @@ class Replicate {
144144
* @returns {Promise<object>} - Resolves with the output of running the model
145145
*/
146146
async run(ref, options, progress) {
147+
let { block } = options;
147148
const { wait, signal, ...data } = options;
148149

150+
// Block if `block` is explicitly true or if `wait` is explicitly true
151+
block = block || (block === undefined && wait === true);
152+
149153
const identifier = ModelVersionIdentifier.parse(ref);
150154

151155
let prediction;
152156
if (identifier.version) {
153157
prediction = await this.predictions.create({
154158
...data,
155159
version: identifier.version,
160+
block,
156161
});
157162
} else if (identifier.owner && identifier.name) {
158163
prediction = await this.predictions.create({
159164
...data,
160165
model: `${identifier.owner}/${identifier.name}`,
166+
block,
161167
});
162168
} else {
163169
throw new Error("Invalid model version identifier");

0 commit comments

Comments
 (0)