Commit 24a5b7d — Merge pull request #475 from sudoleg/copilot/add-ollama-llm-provider

Add Ollama provider option for LLMs and embeddings

2 parents 95b0488 + 6390da3

File tree: 11 files changed, +380 −92 lines changed

README.md — 10 additions, 3 deletions

@@ -6,7 +6,9 @@
 ## Features :sparkles:
 
-YouTubeGPT lets you **summarize and chat (Q&A)** with YouTube videos. Its features include:
+YouTubeGPT is a web app that can be run fully locally and lets you **summarize and chat (Q&A)** with YouTube videos. You can either use OpenAI's API or a (local) Ollama instance.
+
+YouTubeGPT's features include:
 
 ### :writing_hand: Provide a custom prompt for summaries [**VIEW DEMO**](https://youtu.be/rJqx3qvebws)
 
@@ -21,8 +23,9 @@
 - the summaries and answers can be saved to a library accessible at a separate page!
 - additionally, summaries and answers can be exported/downloaded as Markdown files!
 
-### :robot: Choose from different OpenAI models
+### :robot: Choose provider and models
 
+- choose between OpenAI's API or a (local) Ollama instance
 - currently available: ChatGPT 4-5 (incl. nano & mini) and *continuously updated* with new models
 - by choosing a different model, you can summarize even longer videos and get better responses
 
@@ -36,7 +39,11 @@
 ## Installation & usage
 
-No matter how you choose to run the app, you will first need to get an OpenAI API-Key. This is very straightforward and free. Have a look at [their instructions](https://platform.openai.com/docs/quickstart/account-setup) to get started.
+If you want to use OpenAI's API, you will first need to get an OpenAI API key. This is very straightforward and free. Have a look at [their instructions](https://platform.openai.com/docs/quickstart/account-setup) to get started.
+
+If you want to use Ollama, you need an Ollama server running locally or remotely. You can download Ollama for macOS, Linux, or Windows [on their website](https://ollama.com/download). Make sure the server is reachable on the default port `11434`, or set the `OLLAMA_HOST` environment variable to point to your Ollama server. You also need to **pull the models** you want to use.
+
+> **Note**: Ollama limits the context window to 4k tokens by default. I strongly recommend adjusting it to at least 16k tokens. This can be done in the Ollama app settings.
 
 ### Run with Docker
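The README's host-resolution rule (default port `11434`, overridable via `OLLAMA_HOST`) can be sketched in a few lines of Python. This is illustrative only — the function names here are hypothetical, not part of the app — and the probe assumes Ollama's behavior of answering a plain GET on its root URL when running:

```python
import os
import urllib.request


def resolve_ollama_host() -> str:
    """Read OLLAMA_HOST, falling back to Ollama's default local endpoint."""
    return os.getenv("OLLAMA_HOST", "http://localhost:11434")


def ollama_reachable(host: str, timeout: float = 2.0) -> bool:
    """Probe the server root; a running Ollama instance answers HTTP 200."""
    try:
        with urllib.request.urlopen(host, timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # Covers connection refused, DNS failure, and timeouts.
        return False
```

The same env-var fallback is what the app's `get_ollama_host()` helper (added in `modules/helpers.py` below) implements.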

docker-compose.yml — 13 additions, 12 deletions

@@ -28,24 +28,25 @@ services:
       # replace with your OpenAI API key or the name of the environment
       # variable that stores it on your PC
       - OPENAI_API_KEY=${OPENAI_YOUTUBEGPT_API_KEY}
+      - OLLAMA_HOST=http://host.docker.internal:11434
     ports:
       - "8501:8501"
     networks:
       - net
 
-  ollama:
-    image: ollama/ollama:0.13.5
-    container_name: ollama
-    volumes:
-      - ollama:/root/.ollama
-    ports:
-      - "11434:11434"
-    restart: unless-stopped
-    networks:
-      - net
+  # ollama:
+  #   image: ollama/ollama:0.13.3
+  #   container_name: ollama
+  #   volumes:
+  #     - ollama:/root/.ollama
+  #   ports:
+  #     - "11434:11434"
+  #   restart: unless-stopped
+  #   networks:
+  #     - net
 
 volumes:
   chroma:
     driver: local
-  ollama:
-    driver: local
+  # ollama:
+  #   driver: local

modules/helpers.py — 67 additions, 8 deletions

@@ -3,8 +3,9 @@
 import os
 import re
 from pathlib import Path
-from typing import List, Literal
+from typing import List, Literal, Optional
 
+import ollama
 import openai
 import streamlit as st
 import tiktoken
@@ -68,13 +69,17 @@ def get_available_models(
         get_default_config_value(f"available_models.{model_type}")
     )
 
+    def _filter_available(models: List[str]) -> List[str]:
+        return [m for m in selectable_model_ids if m in models]
+
+    if not api_key and not os.getenv("OPENAI_API_KEY"):
+        return selectable_model_ids
+
     # AVAILABLE_MODEL_IDS env var stores all the model IDs available to the user as a list (separated by a comma)
     # the env var is set programmatically below
     available_model_ids = os.getenv("AVAILABLE_MODEL_IDS")
     if available_model_ids:
-        return filter(
-            lambda m: m in available_model_ids.split(","), selectable_model_ids
-        )
+        return _filter_available(available_model_ids.split(","))
 
     try:
         available_model_ids: list = [model.id for model in openai.models.list()]
@@ -94,7 +99,7 @@ def get_available_models(
     # set the AVAILABLE_MODEL_IDS env var, so that the list of available models
     # doesn't have to be fetched every time
     os.environ["AVAILABLE_MODEL_IDS"] = ",".join(available_model_ids)
-    return filter(lambda m: m in available_model_ids, selectable_model_ids)
+    return _filter_available(available_model_ids)
 
 
 def get_default_config_value(
@@ -203,14 +208,16 @@ def save_response_as_file(
     logging.info("File saved at: %s", file_path)
 
 
-def get_preffered_languages():
+def get_preferred_languages():
+    """Return preferred languages for transcripts."""
     # TODO: return from configuration object or config.json
     return ["en-US", "en", "de"]
 
 
-def num_tokens_from_string(string: str, model: str = "gpt-4o-mini") -> int:
+# TODO: handle Ollama models as well or fall back to another token-counting method
+def num_tokens_from_string(string: str, model: str = "gpt-4.1-nano") -> int:
     """
-    Returns the number of tokens in a text string.
+    Returns the number of tokens in a text string for OpenAI models.
 
     Args:
         string (str): The string to count tokens in.
@@ -238,3 +245,55 @@ def is_environment_prod():
     if os.getenv("ENVIRONMENT") == "production":
         return True
     return False
+
+
+def get_ollama_host() -> str:
+    """Return the configured Ollama host."""
+    return os.getenv("OLLAMA_HOST", "http://localhost:11434")
+
+
+def is_ollama_available(host: Optional[str] = None) -> bool:
+    """Checks whether an Ollama server is reachable."""
+    ollama_host = host or get_ollama_host()
+    try:
+        ollama.Client(host=ollama_host).list()
+    except Exception as e:
+        logging.error("Ollama connection check failed: %s", str(e))
+        return False
+    return True
+
+
+def _is_embedding_model(model: dict) -> bool:
+    """Determine whether an Ollama model is an embedding model."""
+    details = model.get("details", {})
+    family = details.get("family", "").lower()
+    model_type = details.get("model_type", "").lower()
+    name = model.get("model", "").lower()
+    return "embed" in family or "embedding" in model_type or "embed" in name
+
+
+def get_ollama_models(
+    model_type: Literal["gpts", "embeddings"], host: Optional[str] = None
+) -> List[str]:
+    """Returns available Ollama models filtered by type."""
+    ollama_host = host or get_ollama_host()
+    try:
+        models = ollama.Client(host=ollama_host).list().get("models", [])
+    except Exception as e:
+        logging.error("Could not list Ollama models: %s", str(e))
+        return []
+
+    if model_type == "embeddings":
+        return [model["model"] for model in models if _is_embedding_model(model)]
+    return [model["model"] for model in models if not _is_embedding_model(model)]
+
+
+def pull_ollama_model(model_name: str, host: Optional[str] = None) -> bool:
+    """Triggers pulling an Ollama model; returns True on success."""
+    ollama_host = host or get_ollama_host()
+    try:
+        ollama.Client(host=ollama_host).pull(model=model_name, stream=False)
+    except Exception as e:
+        logging.error("Failed to pull Ollama model %s: %s", model_name, str(e))
+        return False
+    return True
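Two small pieces of logic in this file are worth seeing in isolation: the order-preserving model filter (which keeps the curated ordering from the app config rather than the server's ordering) and the heuristic that classifies an Ollama model as an embedding model. A minimal, dependency-free sketch — the names mirror the diff, but this standalone version is illustrative:

```python
from typing import List


def filter_available(selectable_model_ids: List[str], models: List[str]) -> List[str]:
    # Preserve the config's curated ordering; drop anything the
    # account or server does not actually offer.
    return [m for m in selectable_model_ids if m in models]


def is_embedding_model(model: dict) -> bool:
    # Heuristic from the diff: look for "embed" in the family,
    # model_type, or model name reported by the Ollama API.
    details = model.get("details", {})
    family = details.get("family", "").lower()
    model_type = details.get("model_type", "").lower()
    name = model.get("model", "").lower()
    return "embed" in family or "embedding" in model_type or "embed" in name
```

Note that replacing the old lazy `filter(...)` with a list comprehension also fixes a subtle bug: the previous code returned a one-shot iterator, which Streamlit widgets could exhaust on a rerun.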

modules/summary.py — 25 additions, 15 deletions

@@ -1,7 +1,8 @@
 import logging
 
+import ollama
 from langchain.messages import HumanMessage, SystemMessage
-from langchain_openai import ChatOpenAI
+from langchain_core.language_models import BaseChatModel
 
 from .helpers import num_tokens_from_string, read_file
 
@@ -10,7 +11,7 @@
 USER_PROMPT_TEMPLATE = read_file("prompts/summary_user_prompt.txt")
 
 # info about OpenAI's GPTs context windows: https://platform.openai.com/docs/models
-CONTEXT_WINDOWS = {
+OPENAI_CONTEXT_WINDOWS = {
     "gpt-3.5-turbo": {"total": 16385, "output": 4096},
     "gpt-4": {"total": 8192, "output": 4096},
     "gpt-4-turbo": {"total": 128000, "output": 4096},
@@ -40,13 +41,13 @@ def log_error(self):
         logging.error("Transcript too long for %s.", self.model_name, exc_info=True)
 
 
-def get_transcript_summary(transcript_text: str, llm: ChatOpenAI, **kwargs):
+def get_transcript_summary(transcript_text: str, llm: BaseChatModel, **kwargs):
     """
     Generates a summary from a video transcript using a language model.
 
     Args:
         transcript_text (str): The full transcript text of the video.
-        llm (ChatOpenAI): The language model instance to use for generating the summary.
+        llm (BaseChatModel): The language model instance to use for generating the summary.
         **kwargs: Optional keyword arguments.
             - custom_prompt (str): A custom prompt to replace the default summary request.
 
@@ -68,30 +69,39 @@ def get_transcript_summary(transcript_text: str, llm: BaseChatModel, **kwargs):
         ---
         {transcript_text}
         ---
-        """
-
+        """
     else:
         user_prompt = USER_PROMPT_TEMPLATE.format(transcript_text=transcript_text)
 
+    if llm.name not in OPENAI_CONTEXT_WINDOWS.keys():
+        model_details = ollama.show(model=llm.name)
+        model_info = model_details.get("modelinfo", {})
+        general_arch = model_info.get("general.architecture", "")
+        max_context_length = model_info.get(f"{general_arch}.context_length", 4096)
+    else:
+        max_context_length = OPENAI_CONTEXT_WINDOWS[llm.name]["total"]
+
     # if the number of tokens in the transcript (plus the number of tokens in the prompt) exceeds the model's context window, an exception is raised
-    if (
-        num_tokens_from_string(string=user_prompt, model=llm.model_name)
-        + num_tokens_from_string(string=SYSTEM_PROMPT, model=llm.model_name)
-        > CONTEXT_WINDOWS[llm.model_name]["total"]
-    ):
+    total_tokens = num_tokens_from_string(
+        string=user_prompt, model=llm.name
+    ) + num_tokens_from_string(string=SYSTEM_PROMPT, model=llm.name)
+    if total_tokens > max_context_length:
         raise TranscriptTooLongForModelException(
-            message=f"Your transcript exceeds the context window of the chosen model ({llm.model_name}), which is {CONTEXT_WINDOWS[llm.model_name]['total']} tokens. "
+            message=f"Your transcript exceeds the context window of the chosen model ({llm.name}), which is {max_context_length} tokens. "
             "Consider the following options:\n"
-            "1. Choose another model with larger context window (such as gpt-4o).\n"
+            "1. Choose another model with a larger context window.\n"
             "2. Use the 'Chat' feature to ask specific questions about the video. There you won't be limited by the number of tokens.\n\n"
-            "You can get more information on context windows for different models in the [official OpenAI documentation about models](https://platform.openai.com/docs/models).",
-            model_name=llm.model_name,
+            "You can get more information on context windows for different OpenAI models in the [official documentation](https://platform.openai.com/docs/models).",
+            model_name=llm.name,
         )
 
     messages = [
         SystemMessage(content=SYSTEM_PROMPT),
         HumanMessage(content=user_prompt),
     ]
 
+    logging.info(
+        "Generating summary using model: %s. Total tokens: %d", llm.name, total_tokens
+    )
     response = llm.invoke(messages)
     return response.content
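The context-window guard introduced here reduces to: resolve the model's maximum context (static table for OpenAI models, an `ollama.show` lookup otherwise), then compare it against the combined token count of the system and user prompts. A provider-agnostic sketch with a crude stand-in token counter — the real code uses tiktoken and the Ollama client, and `fits_context` is a hypothetical name:

```python
# Subset of the table from modules/summary.py.
OPENAI_CONTEXT_WINDOWS = {"gpt-4": {"total": 8192}, "gpt-4-turbo": {"total": 128000}}


def count_tokens(text: str) -> int:
    # Crude stand-in for tiktoken: one "token" per whitespace-separated word.
    return len(text.split())


def fits_context(model: str, system_prompt: str, user_prompt: str,
                 non_openai_context: int = 4096) -> bool:
    if model in OPENAI_CONTEXT_WINDOWS:
        max_context = OPENAI_CONTEXT_WINDOWS[model]["total"]
    else:
        # The diff reads this limit from `ollama.show` for Ollama models;
        # here we just take a caller-supplied fallback (4096 mirrors the
        # diff's default).
        max_context = non_openai_context
    return count_tokens(system_prompt) + count_tokens(user_prompt) <= max_context
```

When this check fails, the real code raises `TranscriptTooLongForModelException` instead of returning `False`.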

modules/ui.py — 42 additions, 4 deletions

@@ -5,15 +5,19 @@
 from modules.helpers import (
     get_available_models,
     get_default_config_value,
+    get_ollama_models,
     is_api_key_set,
     is_api_key_valid,
+    is_ollama_available,
 )
 
 GENERAL_ERROR_MESSAGE = "An unexpected error occurred. If you are a developer and run the app locally, you can view the logs to see details about the error."
 
 
 def display_api_key_warning():
     """Checks whether an API key is provided and displays warning if not."""
+    if st.session_state.get("llm_provider", "OpenAI") == "Ollama":
+        return
     if not is_api_key_set():
         st.warning(
             """:warning: It seems you haven't provided an API key yet. Make sure to do so by providing it in the settings (sidebar)
@@ -58,18 +62,50 @@ def display_model_settings_sidebar():
     For example here, the selectbox for model has the key 'model'.
     Thus the selected model can be accessed via st.session_state.model.
     """
+    if "llm_provider" not in st.session_state:
+        st.session_state.llm_provider = "OpenAI"
     if "model" not in st.session_state:
         st.session_state.model = get_default_config_value("default_model.gpt")
 
     with st.sidebar:
         st.header("Model settings")
+        provider = st.selectbox(
+            label="Select provider",
+            options=["OpenAI", "Ollama"],
+            key="llm_provider",
+        )
+        available_models = []
+        if provider == "Ollama":
+            ollama_ready = is_ollama_available()
+            if not ollama_ready:
+                st.warning(
+                    "Ollama server not reachable. Ensure it is running locally on port 11434 or set the host via the `OLLAMA_HOST` environment variable."
+                )
+            available_models = (
+                get_ollama_models(model_type="gpts") if ollama_ready else []
+            )
+            if not available_models:
+                st.warning(
+                    "No Ollama models available. Pull a model in your terminal before proceeding."
+                )
+        else:
+            if is_api_key_set():
+                available_models = get_available_models(
+                    model_type="gpts", api_key=st.session_state.openai_api_key
+                )
+            else:
+                available_models = []
+        if available_models and st.session_state.model not in available_models:
+            st.session_state.model = available_models[0]
+        model_options = (
+            available_models if available_models else [st.session_state.model]
+        )
         model = st.selectbox(
             label="Select a large language model",
-            options=get_available_models(
-                model_type="gpts", api_key=st.session_state.openai_api_key
-            ),
+            options=model_options,
             key="model",
             help=get_default_config_value("help_texts.model"),
+            disabled=not available_models,
         )
         st.slider(
             label="Adjust temperature",
@@ -93,7 +129,9 @@ def display_model_settings_sidebar():
         st.warning(
             "OpenAI generally recommends altering temperature or top_p but not both. See their [API reference](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature)"
         )
-        if model != get_default_config_value("default_model.gpt"):
+        if provider == "OpenAI" and model != get_default_config_value(
+            "default_model.gpt"
+        ):
             st.warning(
                 """:warning: Be aware of the higher costs and latencies when using more advanced (reasoning) models (like gpt-5). You can see details (incl. costs) about the models and compare them [here](https://platform.openai.com/docs/models/compare)."""
             )

modules/youtube.py — 2 additions, 2 deletions

@@ -6,7 +6,7 @@
 from youtube_transcript_api import CouldNotRetrieveTranscript, YouTubeTranscriptApi
 from youtube_transcript_api.formatters import TextFormatter
 
-from .helpers import extract_youtube_video_id, get_preffered_languages
+from .helpers import extract_youtube_video_id, get_preferred_languages
 
 OEMBED_PROVIDER = "https://noembed.com/embed"
 
@@ -66,7 +66,7 @@ def fetch_youtube_transcript(url: str):
     try:
         transcript = YouTubeTranscriptApi().fetch(
-            video_id, languages=get_preffered_languages()
+            video_id, languages=get_preferred_languages()
         )
     except CouldNotRetrieveTranscript as e:
         logging.error("Failed to retrieve transcript for URL: %s", str(e))

0 commit comments