Skip to content

Commit a34a645

Browse files
committed
Examples: add replicate.use() demo script
1 parent f41cc05 commit a34a645

File tree

3 files changed

+108
-5
lines changed

3 files changed

+108
-5
lines changed

examples/use_demo.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#!/usr/bin/env python3
"""
Example of using the experimental replicate.use() interface.

Runs four smoke tests against the Replicate API:

1. a simple text model,
2. an image generation model, printing output URLs without downloading,
3. a language model with streaming output, and
4. async prediction and async streaming (``use_async=True``).

Each test catches and reports its own exceptions so that one failure
does not stop the remaining tests.  NOTE(review): presumably requires a
valid ``REPLICATE_API_TOKEN`` in the environment — confirm against the
client docs.
"""

# Stdlib imports hoisted to the top of the file (PEP 8); the original
# scattered `import asyncio` mid-script and re-imported `traceback`
# inside each exception handler.
import asyncio
import traceback

import replicate

print("Testing replicate.use() functionality...")

# Test 1: Simple text model
print("\n1. Testing simple text model...")
try:
    hello_world = replicate.use("replicate/hello-world")
    result = hello_world(text="Alice")
    print(f"Result: {result}")
except Exception as e:  # best-effort demo: report and continue
    print(f"Error: {type(e).__name__}: {e}")

# Test 2: Image generation model
print("\n2. Testing image generation model...")
try:
    # Imported lazily so an ImportError is reported like any other
    # failure of this test instead of aborting the whole script.
    from replicate.lib._predictions_use import get_path_url

    flux_dev = replicate.use("black-forest-labs/flux-dev")
    outputs = flux_dev(
        prompt="a cat wearing a wizard hat, digital art",
        num_outputs=1,
        aspect_ratio="1:1",
        output_format="webp",
        guidance=3.5,
        num_inference_steps=28,
    )
    print(f"Generated output: {outputs}")
    if isinstance(outputs, list):
        print(f"Generated {len(outputs)} image(s)")
        for i, output in enumerate(outputs):
            print(f" Image {i}: {output}")
            # Get the URL without downloading
            url = get_path_url(output)
            if url:
                print(f" URL: {url}")
    else:
        print(f"Single output: {outputs}")
        url = get_path_url(outputs)
        if url:
            print(f" URL: {url}")
except Exception as e:
    print(f"Error: {type(e).__name__}: {e}")
    traceback.print_exc()

# Test 3: Language model with streaming
print("\n3. Testing language model with streaming...")
try:
    llama = replicate.use("meta/meta-llama-3-8b-instruct", streaming=True)
    output = llama(prompt="Write a haiku about Python programming", max_tokens=50)
    print("Streaming output:")
    for chunk in output:
        print(chunk, end="", flush=True)
    print()
except Exception as e:
    print(f"Error: {type(e).__name__}: {e}")
    traceback.print_exc()

# Test 4: Using async
print("\n4. Testing async functionality...")


async def test_async():
    """Exercise replicate.use() with use_async=True (await + async for)."""
    try:
        hello_world = replicate.use("replicate/hello-world", use_async=True)
        result = await hello_world(text="Bob")
        print(f"Async result: {result}")

        print("\n4b. Testing async streaming...")
        llama = replicate.use("meta/meta-llama-3-8b-instruct", streaming=True, use_async=True)
        output = await llama(prompt="Write a short poem about async/await", max_tokens=50)
        print("Async streaming output:")
        async for chunk in output:
            print(chunk, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error: {type(e).__name__}: {e}")
        traceback.print_exc()


asyncio.run(test_async())

print("\nDone!")

src/replicate/_module_client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,13 @@ def __load__(self) -> PredictionsResource:
8787
def _run(*args, **kwargs):
    # Delegate to the lazily-created module-level client.
    return _load_client().run(*args, **kwargs)


def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs):
    """Module-level ``replicate.use(...)`` dispatching to a sync or async client.

    The shared lazily-loaded module client is synchronous, so async
    callers (``use_async=True``) get a dedicated ``AsyncReplicate``
    client instead.
    """
    if not use_async:
        return _load_client().use(ref, hint=hint, streaming=streaming, **kwargs)

    from ._client import AsyncReplicate

    return AsyncReplicate().use(ref, hint=hint, streaming=streaming, **kwargs)


run = _run
use = _use

src/replicate/lib/_predictions_use.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,10 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
466466
version_id = str(version)
467467
prediction = self._client.predictions.create(version=version_id, input=processed_inputs)
468468
else:
469-
prediction = self._client.models.predictions.create(model=self._model, input=processed_inputs)
469+
model = self._model
470+
prediction = self._client.models.predictions.create(
471+
model_owner=model.owner or "", model_name=model.name or "", input=processed_inputs
472+
)
470473

471474
return Run(
472475
client=self._client,
@@ -600,7 +603,7 @@ async def logs(self) -> Optional[str]:
600603
"""
601604
Fetch and return the logs from the prediction asynchronously.
602605
"""
603-
self._prediction = await self._client.predictions.async_get(prediction_id=self._prediction.id)
606+
self._prediction = await self._client.predictions.get(prediction_id=self._prediction.id)
604607

605608
return self._prediction.logs
606609

@@ -623,7 +626,7 @@ async def _async_output_iterator(self) -> AsyncIterator[Any]:
623626
import asyncio
624627

625628
await asyncio.sleep(self._client.poll_interval)
626-
self._prediction = await self._client.predictions.async_get(prediction_id=self._prediction.id)
629+
self._prediction = await self._client.predictions.get(prediction_id=self._prediction.id)
627630

628631
if self._prediction.status == "failed":
629632
raise ModelError(self._prediction)

0 commit comments

Comments (0)