Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "together"
version = "1.4.4"
version = "1.4.5"
authors = [
"Together AI <[email protected]>"
]
Expand Down
10 changes: 9 additions & 1 deletion src/together/client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import os
from typing import Dict
import sys
from typing import Dict, TYPE_CHECKING

from together import resources
from together.constants import BASE_URL, MAX_RETRIES, TIMEOUT_SECS
from together.error import AuthenticationError
from together.types import TogetherClient
from together.utils import enforce_trailing_slash
from together.utils.api_helpers import get_google_colab_secret


class Together:
Expand Down Expand Up @@ -44,6 +46,9 @@ def __init__(
if not api_key:
api_key = os.environ.get("TOGETHER_API_KEY")

if not api_key and "google.colab" in sys.modules:
api_key = get_google_colab_secret("TOGETHER_API_KEY")

if not api_key:
raise AuthenticationError(
"The api_key client option must be set either by passing api_key to the client or by setting the "
Expand Down Expand Up @@ -117,6 +122,9 @@ def __init__(
if not api_key:
api_key = os.environ.get("TOGETHER_API_KEY")

if not api_key and "google.colab" in sys.modules:
api_key = get_google_colab_secret("TOGETHER_API_KEY")

if not api_key:
raise AuthenticationError(
"The api_key client option must be set either by passing api_key to the client or by setting the "
Expand Down
40 changes: 40 additions & 0 deletions src/together/utils/api_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import os
import sys
import platform
from typing import TYPE_CHECKING, Any, Dict

Expand All @@ -12,6 +13,7 @@
import together
from together import error
from together.utils._log import _console_log_level
from together.utils import log_info


def get_headers(
Expand Down Expand Up @@ -82,3 +84,41 @@ def default_api_key(api_key: str | None = None) -> str | None:
return os.environ.get("TOGETHER_API_KEY")

raise error.AuthenticationError(together.constants.MISSING_API_KEY_MESSAGE)


def get_google_colab_secret(secret_name: str = "TOGETHER_API_KEY") -> str | None:
"""
Checks to see if the user is running in Google Colab, and looks for the Together API Key secret.

Args:
secret_name (str, optional). Defaults to TOGETHER_API_KEY

Returns:
str: if the API key is found; None if an error occurred or the secret was not found.
"""
# If running in Google Colab, check for Together in notebook secrets
if "google.colab" in sys.modules:
if TYPE_CHECKING:
from google.colab import userdata # type: ignore
else:
from google.colab import userdata

try:
api_key = userdata.get(secret_name)
if not isinstance(api_key, str):
return None
else:
return str(api_key)
except userdata.NotebookAccessError:
log_info(
"The TOGETHER_API_KEY Colab secret was found, but notebook access is disabled. Please enable notebook "
"access for the secret."
)
except userdata.SecretNotFoundError:
# warn and carry on
log_info("Colab: No Google Colab secret named TOGETHER_API_KEY was found.")

return None

else:
return None
2 changes: 1 addition & 1 deletion tests/integration/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
completion_test_model_list = [
"mistralai/Mistral-7B-v0.1",
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
]
chat_test_model_list = []
embedding_test_model_list = []
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/resources/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_max_tokens(
product(
completion_test_model_list,
completion_prompt_list,
[35000, 40000, 50000],
[200000, 400000, 500000],
),
)
def test_high_max_tokens(
Expand Down