diff --git a/pyproject.toml b/pyproject.toml index ff080968..20078bb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "together" -version = "1.4.4" +version = "1.4.5" authors = [ "Together AI " ] diff --git a/src/together/client.py b/src/together/client.py index ea5359a5..cc86dc0f 100644 --- a/src/together/client.py +++ b/src/together/client.py @@ -1,13 +1,15 @@ from __future__ import annotations import os -from typing import Dict +import sys +from typing import Dict, TYPE_CHECKING from together import resources from together.constants import BASE_URL, MAX_RETRIES, TIMEOUT_SECS from together.error import AuthenticationError from together.types import TogetherClient from together.utils import enforce_trailing_slash +from together.utils.api_helpers import get_google_colab_secret class Together: @@ -44,6 +46,9 @@ def __init__( if not api_key: api_key = os.environ.get("TOGETHER_API_KEY") + if not api_key and "google.colab" in sys.modules: + api_key = get_google_colab_secret("TOGETHER_API_KEY") + if not api_key: raise AuthenticationError( "The api_key client option must be set either by passing api_key to the client or by setting the " @@ -117,6 +122,9 @@ def __init__( if not api_key: api_key = os.environ.get("TOGETHER_API_KEY") + if not api_key and "google.colab" in sys.modules: + api_key = get_google_colab_secret("TOGETHER_API_KEY") + if not api_key: raise AuthenticationError( "The api_key client option must be set either by passing api_key to the client or by setting the " diff --git a/src/together/utils/api_helpers.py b/src/together/utils/api_helpers.py index 2ec9d3f9..615c32d8 100644 --- a/src/together/utils/api_helpers.py +++ b/src/together/utils/api_helpers.py @@ -2,6 +2,7 @@ import json import os +import sys import platform from typing import TYPE_CHECKING, Any, Dict @@ -12,6 +13,7 @@ import together from together import error from together.utils._log import _console_log_level +from together.utils import log_info def get_headers( @@ -82,3 +84,41 @@ def default_api_key(api_key: str | None = None) -> str | None: return os.environ.get("TOGETHER_API_KEY") raise error.AuthenticationError(together.constants.MISSING_API_KEY_MESSAGE) + + +def get_google_colab_secret(secret_name: str = "TOGETHER_API_KEY") -> str | None: + """ + Checks to see if the user is running in Google Colab, and looks for the Together API Key secret. + + Args: + secret_name (str, optional). Defaults to TOGETHER_API_KEY + + Returns: + str: if the API key is found; None if an error occurred or the secret was not found. + """ + # If running in Google Colab, check for Together in notebook secrets + if "google.colab" in sys.modules: + if TYPE_CHECKING: + from google.colab import userdata # type: ignore + else: + from google.colab import userdata + + try: + api_key = userdata.get(secret_name) + if not isinstance(api_key, str): + return None + else: + return str(api_key) + except userdata.NotebookAccessError: + log_info( + "The TOGETHER_API_KEY Colab secret was found, but notebook access is disabled. Please enable notebook " + "access for the secret." + ) + except userdata.SecretNotFoundError: + # warn and carry on + log_info("Colab: No Google Colab secret named TOGETHER_API_KEY was found.") + + return None + + else: + return None diff --git a/tests/integration/constants.py b/tests/integration/constants.py index 0d330a2a..1c56030a 100644 --- a/tests/integration/constants.py +++ b/tests/integration/constants.py @@ -1,5 +1,5 @@ completion_test_model_list = [ - "mistralai/Mistral-7B-v0.1", + "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", ] chat_test_model_list = [] embedding_test_model_list = [] diff --git a/tests/integration/resources/test_completion.py b/tests/integration/resources/test_completion.py index 8f436ace..59d1b140 100644 --- a/tests/integration/resources/test_completion.py +++ b/tests/integration/resources/test_completion.py @@ -213,7 +213,7 @@ def test_max_tokens( product( completion_test_model_list, completion_prompt_list, - [35000, 40000, 50000], + [200000, 400000, 500000], ), ) def test_high_max_tokens(