Skip to content

Commit b1222bd

Browse files
authored
Merge pull request #122 from sudoleg/develop
Performance improvements
2 parents 169ad30 + df85861 commit b1222bd

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

main.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@
22

33
import streamlit as st
44

5-
from modules.helpers import read_file
6-
from modules.ui import (
7-
display_link_to_repo,
8-
display_nav_menu,
9-
)
5+
from modules.helpers import is_api_key_set, is_api_key_valid, read_file
6+
from modules.ui import display_link_to_repo, display_nav_menu
107

118

129
def main():
@@ -18,6 +15,12 @@ def main():
1815
display_nav_menu()
1916
display_link_to_repo()
2017

18+
if not is_api_key_set():
19+
st.info(
20+
"""It looks like you haven't set the API Key as an environment variable.
21+
Don't worry, you can set it in the sidebar when you go to either one of the pages :)"""
22+
)
23+
2124
st.markdown(body=read_file(".assets/home.md"))
2225

2326

modules/helpers.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ def is_api_key_valid(api_key: str):
2727
Returns:
2828
bool: True if the API key is valid, False if the API key is invalid.
2929
"""
30+
31+
api_key_valid = os.getenv("OPENAI_API_KEY_VALID")
32+
if api_key_valid:
33+
return True
34+
3035
openai.api_key = api_key
3136
try:
3237
openai.models.list()
@@ -44,29 +49,39 @@ def is_api_key_valid(api_key: str):
4449
return False
4550
else:
4651
logging.info("API key validation successful")
52+
os.environ["OPENAI_API_KEY_VALID"] = "yes"
4753
return True
4854

4955

5056
def get_available_models(
5157
model_type: Literal["gpts", "embeddings"], api_key: str = ""
5258
) -> List[str]:
5359
"""
54-
Retrieve a list of available model IDs from OpenAI's API filtered by model type.
60+
Retrieve a filtered list of available model IDs from OpenAI's API or environment variables, based on the specified model type.
5561
5662
Args:
57-
model_type (Literal["gpts", "embeddings"]): The type of models to retrieve.
63+
model_type (Literal["gpts", "embeddings"]): The type of models to retrieve, such as 'gpts' or 'embeddings'.
5864
api_key (str, optional): The API key for authenticating with OpenAI. Defaults to an empty string.
5965
6066
Returns:
61-
List[str]: A list of available model IDs filtered by the specified model type.
62-
Returns an empty list if an authentication error or any other exception occurs.
67+
List[str]: A filtered list of available model IDs matching the specified model type. The list is derived either from the environment variable `AVAILABLE_MODEL_IDS` if set, or from a call to OpenAI's API.
68+
If an authentication error or any other exception occurs during the API call, an empty list is returned.
6369
"""
6470
openai.api_key = api_key
6571
selectable_model_ids = list(
6672
get_default_config_value(f"available_models.{model_type}")
6773
)
74+
75+
# AVAILABLE_MODEL_IDS env var stores all the model IDs available to the user as a list (separated by a comma)
76+
# the env var is set programatically below
77+
available_model_ids = os.getenv("AVAILABLE_MODEL_IDS")
78+
if available_model_ids:
79+
return filter(
80+
lambda m: m in available_model_ids.split(","), selectable_model_ids
81+
)
82+
6883
try:
69-
available_model_ids = [model.id for model in openai.models.list()]
84+
available_model_ids: list = [model.id for model in openai.models.list()]
7085
except openai.AuthenticationError as e:
7186
logging.error(
7287
"An authentication error occurred when fetching available models: %s",
@@ -80,6 +95,9 @@ def get_available_models(
8095
)
8196
return []
8297
else:
98+
# set the AVAILABLE_MODEL_IDS env var, so that the list of available models
99+
# doesn't have to be fetched every time
100+
os.environ["AVAILABLE_MODEL_IDS"] = ",".join(available_model_ids)
83101
return filter(lambda m: m in available_model_ids, selectable_model_ids)
84102

85103

0 commit comments

Comments
 (0)