Skip to content

Commit ab12403

Browse files
committed
Refactored ollama adapters to support non-localhost URLs
1 parent 5cbb631 commit ab12403

File tree

7 files changed

+76
-24
lines changed

7 files changed

+76
-24
lines changed

src/talkpipe/llm/chat.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import logging
1010
from pydantic import BaseModel
1111

12-
from talkpipe.llm.config import TALKPIPE_MODEL_NAME, TALKPIPE_SOURCE
12+
from talkpipe.util.constants import TALKPIPE_MODEL_NAME, TALKPIPE_SOURCE
1313
from talkpipe.util.data_manipulation import extract_property
1414

1515

@@ -31,6 +31,11 @@ class LLMPrompt(AbstractSegment):
3131
and TALKPIPE_default_source). If those are not set, the values will be loaded
3232
from the configuration file (~/.talkpipe.toml). If none of those are set, an
3333
error will be raised.
34+
35+
Currently supported sources are "ollama," "openai," and "anthropic." If
36+
you specify "ollama," you can optionally set the OLLAMA_SERVER_URL environment
37+
variable or configuration value to point to a different server. By default,
38+
ollama assumes localhost.
3439
"""
3540

3641
def __init__(

src/talkpipe/llm/config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,3 @@ def getEmbeddingSources()->List[str]:
3434
return list(_embeddingAdapter.keys())
3535

3636

37-
TALKPIPE_MODEL_NAME = "default_model_name"
38-
TALKPIPE_SOURCE = "default_model_source"

src/talkpipe/llm/embedding_adapters.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List
2-
32
import numpy as np
3+
from talkpipe.util.config import get_config
4+
from talkpipe.util.constants import OLLAMA_SERVER_URL
45

56
class AbstractEmbeddingAdapter:
67
"""Abstract class for embedding text.
@@ -44,8 +45,9 @@ def __call__(self, text: str) -> List[float]:
4445
class OllamaEmbedderAdapter(AbstractEmbeddingAdapter):
4546
"""Embedding adapter for Ollama"""
4647

47-
def __init__(self, model: str):
48+
def __init__(self, model: str, server_url: str = None):
4849
super().__init__(model, "ollama")
50+
self._server_url = server_url
4951

5052
def execute(self, text: str) -> List[float]:
5153
try:
@@ -55,7 +57,11 @@ def execute(self, text: str) -> List[float]:
5557
"Ollama is not installed. Please install it with: pip install talkpipe[ollama]"
5658
)
5759

58-
response = ollama.embed(
60+
server_url = self._server_url
61+
if not server_url:
62+
server_url = get_config().get(OLLAMA_SERVER_URL, None)
63+
client = ollama.Client(server_url) if server_url else ollama
64+
response = client.embed(
5965
model=self.model_name,
6066
input=text
6167
)

src/talkpipe/llm/prompt_adapters.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from abc import ABC, abstractmethod
2+
import logging
13
import json
24
from pydantic import BaseModel
3-
import logging
4-
5-
from abc import ABC, abstractmethod
5+
from talkpipe.util.config import get_config
6+
from talkpipe.util.constants import OLLAMA_SERVER_URL
67

78
logger = logging.getLogger(__name__)
89

@@ -73,13 +74,25 @@ def is_available(self) -> bool:
7374
class OllamaPromptAdapter(AbstractLLMPromptAdapter):
7475
"""Prompt adapter for Ollama
7576
77+
Note: By default, ollama assumes localhost for the ollama server.
78+
If your server is running elsewhere, you can set the OLLAMA_SERVER_URL
79+
environment variable or in the configuration, or pass the server_url
80+
parameter.
81+
7682
"""
7783

78-
def __init__(self, model: str, system_prompt: str = "You are a helpful assistant.", multi_turn: bool = True, temperature: float = None, output_format: BaseModel = None):
84+
def __init__(self,
85+
model: str,
86+
system_prompt: str = "You are a helpful assistant.",
87+
multi_turn: bool = True,
88+
temperature: float = None,
89+
output_format: BaseModel = None,
90+
server_url: str = None):
7991
super().__init__(model, "ollama", system_prompt, multi_turn, temperature, output_format)
8092
# Ollama uses 0.5 as default when temperature is not specified
8193
if self._temperature is None:
8294
self._temperature = 0.5
95+
self._server_url = server_url
8396

8497
def execute(self, prompt: str) -> str:
8598
"""Execute the chat model.
@@ -97,7 +110,11 @@ def execute(self, prompt: str) -> str:
97110
self._messages.append({"role": "user", "content": prompt})
98111

99112
logger.debug(f"Sending chat request to Ollama model {self._model_name}")
100-
response = ollama.chat(
113+
server_url = self._server_url
114+
if not server_url:
115+
server_url = get_config().get(OLLAMA_SERVER_URL, None)
116+
client = ollama.Client(server_url) if server_url else ollama
117+
response = client.chat(
101118
self._model_name,
102119
messages=[self._system_message] + self._messages,
103120
format=self._output_format.model_json_schema() if self._output_format else None,

src/talkpipe/util/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
TALKPIPE_MODEL_NAME = "default_model_name"
2+
TALKPIPE_SOURCE = "default_model_source"
3+
OLLAMA_SERVER_URL = "OLLAMA_SERVER_URL"

tests/talkpipe/llm/test_chat.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,29 @@ def test_chat_property(requires_ollama):
161161
assert isinstance(response, str)
162162
assert "HELLO" in response
163163

164+
def test_chat_ollama_custom_host(requires_ollama, monkeypatch):
165+
# Test that OLLAMA_SERVER_URL environment variable is respected
166+
custom_url = "http://custom-ollama:11434"
167+
monkeypatch.setenv("TALKPIPE_OLLAMA_SERVER_URL", custom_url)
168+
169+
chat = LLMPrompt(model="llama3.2", source="ollama", temperature=0.0)
170+
chat = chat.as_function(single_in=True, single_out=True)
171+
# Mock the ollama.Client to verify custom URL is passed
172+
def mock_client_init(self, host=None, **kwargs):
173+
assert host == custom_url, f"Expected host {custom_url}, got {host}"
174+
# Store original attributes that might be needed
175+
self.host = host
176+
177+
monkeypatch.setattr("ollama.Client.__init__", mock_client_init)
178+
179+
# Test that the chat works with custom URL
180+
try:
181+
response = chat("Hello, this is a test.")
182+
except Exception:
183+
# The mock will prevent actual connection, but we verified the URL was passed correctly
184+
pass
185+
186+
164187
class AnAnswer(BaseModel):
165188
ans: int
166189

tests/talkpipe/util/test_util.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -179,43 +179,43 @@ def test_get_config(tmp_path, monkeypatch):
179179
assert talkpipe.util.config._config is None
180180

181181
# Add environment variables in a controlled way
182-
monkeypatch.setenv("TALKPIPE_" + talkpipe.llm.config.TALKPIPE_MODEL_NAME, "llama3.1")
183-
monkeypatch.setenv("TALKPIPE_" + talkpipe.llm.config.TALKPIPE_SOURCE, "ollama")
184-
182+
monkeypatch.setenv("TALKPIPE_" + talkpipe.util.constants.TALKPIPE_MODEL_NAME, "llama3.1")
183+
monkeypatch.setenv("TALKPIPE_" + talkpipe.util.constants.TALKPIPE_SOURCE, "ollama")
184+
185185
# When no configuration file exists, get_config will initialize the config.
186186
cfg = talkpipe.util.config.get_config(path=test_path)
187187
assert talkpipe.util.config._config is not None
188188
assert len(cfg) == 2
189-
assert talkpipe.llm.config.TALKPIPE_MODEL_NAME in cfg
190-
assert talkpipe.llm.config.TALKPIPE_SOURCE in cfg
189+
assert talkpipe.util.constants.TALKPIPE_MODEL_NAME in cfg
190+
assert talkpipe.util.constants.TALKPIPE_SOURCE in cfg
191191

192192
# Write a configuration file with values that differ from the env vars.
193193
with open(test_path, "w") as file:
194194
file.write(
195195
"""
196196
%s = "silly"
197197
%s = "beans"
198-
""" % (talkpipe.llm.config.TALKPIPE_MODEL_NAME, talkpipe.llm.config.TALKPIPE_SOURCE)
198+
""" % (talkpipe.util.constants.TALKPIPE_MODEL_NAME, talkpipe.util.constants.TALKPIPE_SOURCE)
199199
)
200200

201201
# Reload the config: env vars should override the values in the file.
202202
cfg = talkpipe.util.config.get_config(path=test_path, reload=True)
203203
assert len(cfg) == 2
204204
# The environment variable values take precedence over the file.
205-
assert cfg[talkpipe.llm.config.TALKPIPE_MODEL_NAME] == "llama3.1"
206-
assert cfg[talkpipe.llm.config.TALKPIPE_SOURCE] == "ollama"
205+
assert cfg[talkpipe.util.constants.TALKPIPE_MODEL_NAME] == "llama3.1"
206+
assert cfg[talkpipe.util.constants.TALKPIPE_SOURCE] == "ollama"
207207

208208
# Remove environment variables to test file-only mode
209-
monkeypatch.delenv("TALKPIPE_" + talkpipe.llm.config.TALKPIPE_MODEL_NAME)
210-
monkeypatch.delenv("TALKPIPE_" + talkpipe.llm.config.TALKPIPE_SOURCE)
209+
monkeypatch.delenv("TALKPIPE_" + talkpipe.util.constants.TALKPIPE_MODEL_NAME)
210+
monkeypatch.delenv("TALKPIPE_" + talkpipe.util.constants.TALKPIPE_SOURCE)
211211

212212
talkpipe.util.config.reset_config()
213213
# Here, ignore_env=True tells get_config to load values directly from the file.
214214
cfg = talkpipe.util.config.get_config(path=test_path, reload=True, ignore_env=True)
215215
assert len(cfg) == 2
216-
assert cfg[talkpipe.llm.config.TALKPIPE_MODEL_NAME] == "silly"
217-
assert cfg[talkpipe.llm.config.TALKPIPE_SOURCE] == "beans"
218-
216+
assert cfg[talkpipe.util.constants.TALKPIPE_MODEL_NAME] == "silly"
217+
assert cfg[talkpipe.util.constants.TALKPIPE_SOURCE] == "beans"
218+
219219
finally:
220220
# Ensure we reset everything at the end
221221
talkpipe.util.config.reset_config()

0 commit comments

Comments (0)