This document provides a comprehensive guide to the safety-tooling repository, covering its purpose, features, installation, and usage of its core components and APIs. It was generated by Gemini and may contain errors or hallucinations. If in doubt, check the original code.
Summary:

`safety-tooling` is a Python library designed to be shared across many AI Safety projects. Its primary function is to provide a unified and robust LLM (Large Language Model) API interface for interacting with models from various providers, including OpenAI, Anthropic, and Google. The repository aims to standardize common tasks, streamline research workflows, and accelerate the onboarding of new AI Safety researchers. It also serves as a platform for developers to upskill by contributing to a widely-used codebase.
Target User/Developer: This project is targeted at AI Safety researchers and developers who need to:
- Interact with multiple LLM providers through a common interface.
- Manage API keys, rate limits, and caching efficiently.
- Run and track finetuning experiments.
- Utilize utilities for common research tasks like prompt templating, data processing, and human labeling.
- Unified LLM Inference API: A single interface (`InferenceAPI`) for making calls to OpenAI, Anthropic, Google Gemini, HuggingFace, Together AI, OpenRouter, DeepSeek, and local VLLM models.
- Batch Inference API: Efficiently process large numbers of prompts using OpenAI and Anthropic batch capabilities.
- Caching Mechanism:
- Automatic caching of LLM API calls to disk (default) or Redis to save costs and time on repeated requests.
- Configurable cache directory and Redis integration.
- Prompt Logging: Optional logging of prompts and responses for debugging and historical reference.
- Rate Limit Management:
- Intelligent handling of OpenAI rate limits, optimizing throughput.
- Customizable concurrency settings for different providers.
- Response Validation: Support for custom validation functions to ensure LLM outputs meet specific criteria, with automatic retries.
- Cost Tracking: Running total of estimated costs for OpenAI models.
- OpenAI Services Support:
- Chat and completion models.
- Moderation API.
- Embedding API.
- Realtime Speech-to-Speech (S2S) API.
- Finetuning Support:
- Scripts for launching and tracking OpenAI and Together AI finetuning jobs.
- Integration with Weights & Biases for experiment logging.
- Text-to-Speech (TTS): Interface for ElevenLabs TTS API.
- API Usage Tracking: Scripts to monitor OpenAI and Anthropic API utilization.
- Experiment Utilities:
  - A base configuration class (`ExperimentConfigBase`) for standardizing experiment script arguments, logging, and API initialization.
  - Helper functions for various tasks like plotting, prompt loading (Jinja templating), image/audio processing, and human labeling.
- API Key Management: Centralized API key management via a `.env` file.
- VLLM Integration: Utilities to easily deploy models locally with VLLM and integrate them into the `InferenceAPI`.
- Primary Programming Language: Python
- Key Frameworks, Libraries, or Significant Dependencies:
- LLM APIs: `openai`, `anthropic`, `google-cloud-aiplatform`, `google-generativeai`, `together`, `grayswan-api`
- TTS: `elevenlabs`
- Data Handling & Serialization: `pydantic`, `pandas`, `jsonlines`
- Experiment Tracking: `wandb` (Weights & Biases)
- Utility: `simple-parsing` (for CLI arguments), `python-dotenv` (for .env management), `uv` (for environment/package management), `requests`, `aiohttp`, `httpx` (for HTTP requests), `tiktoken` (for token counting), `redis`, `filelock`
- Development & Testing: `pytest`, `ruff`, `black`, `pre-commit`
- Notebooks: `jupyterlab`, `ipykernel`
- Audio/Image: `pydub`, `soundfile`, `librosa`, `opencv-python`
- Python: Version 3.11 is recommended (as used in CI and pre-commit configurations).
- `uv`: Recommended for Python environment and package management. Install with:

  ```bash
  curl -LsSf https://astral.sh/uv/install.sh | sh
  source $HOME/.local/bin/env
  ```
- API Keys: You will need API keys for the services you intend to use. These should be stored in a `.env` file (see Configuration section).
  - `OPENAI_API_KEY`: For OpenAI models.
  - `ANTHROPIC_API_KEY`: For Anthropic models.
  - `ANTHROPIC_BATCH_API_KEY`: (Optional) A specific key for Anthropic's batch API if different.
  - `GOOGLE_API_KEY`: For Google Gemini models via the `genai` SDK.
  - `TOGETHER_API_KEY`: For Together AI models.
  - `OPENROUTER_API_KEY`: For OpenRouter.ai.
  - `HF_TOKEN`: For Hugging Face Inference Endpoints or private models.
  - `DEEPSEEK_API_KEY`: For DeepSeek models.
  - `GRAYSWAN_API_KEY`: For GraySwan models.
  - `ELEVENLABS_API_KEY`: For ElevenLabs TTS.
  - `GOOGLE_PROJECT_ID` & `GOOGLE_PROJECT_REGION`: For Google Gemini models via Vertex AI.
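Before running anything, it can be handy to check which of the keys above are actually set in your environment. A small self-contained sketch (this helper is illustrative and not part of safety-tooling):

```python
import os

# Hypothetical helper (not part of safety-tooling): report which provider
# keys from the list above are missing from the current environment.
PROVIDER_KEYS = [
    "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "GOOGLE_API_KEY",
    "TOGETHER_API_KEY", "OPENROUTER_API_KEY", "HF_TOKEN",
    "DEEPSEEK_API_KEY", "GRAYSWAN_API_KEY", "ELEVENLABS_API_KEY",
]

def missing_keys(required: list[str]) -> list[str]:
    """Return the subset of `required` keys not set in the environment."""
    return [k for k in required if not os.getenv(k)]

print(missing_keys(PROVIDER_KEYS))
```

Features depending on any key this prints will fail until you add it to `.env`.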
- Git: For cloning the repository.
- (Optional) Redis: If you want to use Redis for caching. Install Redis from redis.io.
- (Optional) System dependencies for audio processing (if tests or certain features are used):
  (These commands are run in the GitHub Actions CI.)

  ```bash
  sudo apt-get update
  sudo apt-get install -y libsndfile1 ffmpeg
  ```
There are two main ways to install `safety-tooling`: from a local clone (recommended for development/research) or directly via pip.
- Clone the repository:

  ```bash
  git clone git@github.com:safety-research/safety-tooling.git
  cd safety-tooling
  ```
- Create and activate a virtual environment using `uv` (with Python 3.11):

  ```bash
  uv venv --python=python3.11
  source .venv/bin/activate
  ```

  (If you are not using `uv`, create a virtual environment using your preferred method, e.g., `python3.11 -m venv .venv`.)
- Install the package in editable mode along with development dependencies:

  ```bash
  uv pip install -e .
  uv pip install -r requirements_dev.txt
  ```

  This installs the `safetytooling` package itself and additional tools needed for development, such as linters and testing frameworks.
- (Optional) Install a Jupyter kernel if you plan to use Jupyter notebooks within this environment:

  ```bash
  python -m ipykernel install --user --name=safety-tooling-venv
  ```

  (Replace `safety-tooling-venv` with your preferred kernel display name.)
If you don't expect to make changes to safety-tooling and want to use it as a dependency in another project:

```bash
pip install git+https://github.com/safety-research/safety-tooling.git@main#egg=safetytooling
```

Replace `main` with a specific branch name or tag if needed.
Configuration primarily involves setting up API keys and optional parameters through a .env file.
- Create a `.env` file: Copy the example file to create your local configuration:

  ```bash
  cp .env.example .env
  ```
- Edit the `.env` file: Open `.env` and add your API keys and any other desired settings. All keys are optional; features relying on a missing key will not work.

  ```bash
  ## All are optional but features that rely on them will not work if they are not set.
  # required for openai API
  OPENAI_API_KEY=sk-yourOpenAIkey
  # required for anthropic API
  ANTHROPIC_API_KEY=sk-ant-yourAnthropicKey
  # required for gemini API (Google AI Studio)
  GOOGLE_API_KEY=yourGoogleAIStudioKey
  # required for together AI API
  TOGETHER_API_KEY=yourTogetherAIKey
  # required for openrouter API
  OPENROUTER_API_KEY=sk-or-yourOpenRouterKey
  # required for huggingface inference endpoints
  HF_TOKEN=hf_yourHuggingFaceToken
  # required for deepseek API
  DEEPSEEK_API_KEY=sk_yourDeepSeekKey
  # required for gray swan API
  GRAYSWAN_API_KEY=yourGraySwanKey
  # required for elevenlabs text to speech
  ELEVENLABS_API_KEY=yourElevenLabsKey
  # required for gemini via vertex AI
  GOOGLE_PROJECT_ID=your-gcp-project-id
  # required for gemini via vertex AI
  GOOGLE_PROJECT_REGION=your-gcp-project-region

  ## alternative keys for other orgs (pass openai_tag and anthropic_tag to utils.setup_environment to switch between them)
  OPENAI_API_KEY1=sk-anotherOpenAIKey
  OPENAI_API_KEY2=sk-yetAnotherOpenAIKey
  # If you have a separate key for Anthropic batch
  ANTHROPIC_BATCH_API_KEY=sk-ant-yourAnthropicBatchKey

  ## optional for setting up extra logging/caching
  # print prompts and responses to console (set to "true" to enable)
  SAFETYTOOLING_PRINT_PROMPTS=false
  # directory to save prompt history (default: .prompt_history in repo root)
  PROMPT_HISTORY_DIR=./my_prompt_history
  # directory to save cached results (default: .cache in repo root)
  CACHE_DIR=./my_cache
  # use redis for caching (set to "true" to enable)
  REDIS_CACHE=false
  # password for redis (if applicable)
  REDIS_PASSWORD=yourRedisPassword
  ```
  Parameter meanings:
  - `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, etc.: Your secret API keys for the respective services.
  - `OPENAI_API_KEY1`, `OPENAI_API_KEY2`, `ANTHROPIC_BATCH_API_KEY`: Alternative keys. You can switch to these using the `openai_tag` or `anthropic_tag` parameters in `safetytooling.utils.utils.setup_environment()` or by passing them directly to the `InferenceAPI` constructor. For example, `utils.setup_environment(openai_tag="OPENAI_API_KEY1")` would make the OpenAI client use the key defined in `OPENAI_API_KEY1`.
  - `SAFETYTOOLING_PRINT_PROMPTS`: If `true`, prompts and responses will be printed to the console during `InferenceAPI` calls. Default is `false`.
  - `PROMPT_HISTORY_DIR`: Specifies a directory to save a history of prompts and their responses as text files. Useful for debugging. Defaults to `.prompt_history` in the repository root if not set; prompt history can be disabled by initializing `InferenceAPI` with `prompt_history_dir=None`.
  - `CACHE_DIR`: Specifies a directory for file-based caching of API responses. Defaults to `.cache` in the repository root if not set; caching can be disabled by setting `NO_CACHE=true` or initializing `InferenceAPI` with `cache_dir=None`.
  - `REDIS_CACHE`: Set to `true` to use Redis for caching instead of disk-based caching. Redis must be installed and running. Default is `false`.
  - `REDIS_PASSWORD`: Password for your Redis instance, if it's password-protected. Default is no password.
- Loading environment variables: The `safetytooling.utils.utils.setup_environment()` function is typically called at the beginning of scripts or notebooks. It loads variables from the `.env` file into the environment.

  ```python
  from safetytooling.utils import utils

  utils.setup_environment()  # Loads default keys

  # To use an alternative OpenAI key defined as OPENAI_API_KEY1 in .env:
  # utils.setup_environment(openai_tag="OPENAI_API_KEY1")
  ```
- Redis configuration (optional): If `REDIS_CACHE=true` is set in your `.env` file (or environment):
  - Host: Defaults to `localhost`.
  - Port: Defaults to `6379`.
  - Password: Set via `REDIS_PASSWORD` in `.env` or environment. Defaults to no authentication.

  To monitor Redis activity, you can run `redis-cli` and then `MONITOR`.
This section details the main components of safety-tooling.
- Description: The `InferenceAPI` class (`safetytooling.apis.InferenceAPI`) is the central component for interacting with various LLMs. It provides a unified interface to make API calls, manages rate limits, handles caching, and logs prompts/responses. It supports models from OpenAI (chat, completion, moderation, embedding, S2S), Anthropic, Google Gemini (via GenAI and Vertex AI SDKs), HuggingFace Inference Endpoints, Together AI, OpenRouter, DeepSeek, and locally deployed VLLM instances.
- Usage (Library):
  - Import & Initialization:

    ```python
    from pathlib import Path

    from safetytooling.apis import InferenceAPI
    from safetytooling.utils import utils

    # Load API keys from .env and set up environment
    utils.setup_environment()

    # Initialize the API.
    # Defaults: cache_dir=".cache", prompt_history_dir=".prompt_history"
    # To disable caching: cache_dir=None or set NO_CACHE=true in .env
    # To disable prompt history: prompt_history_dir=None
    api = InferenceAPI(
        cache_dir=Path(".my_cache_dir"),
        prompt_history_dir=Path(".my_prompt_history"),
        # openai_api_key="sk-...",  # Optionally pass keys directly
        # anthropic_api_key="sk-ant-...",
        print_prompt_and_response=True,  # Overrides .env setting
    )
    ```
    `InferenceAPI` constructor parameters:
    - `anthropic_num_threads` (int, default: 80): Max concurrent requests to Anthropic.
    - `openai_fraction_rate_limit` (float, default: 0.8): Fraction of OpenAI's reported rate limit to utilize.
    - `openai_num_threads` (int, default: 100): Max concurrent requests to OpenAI (non-S2S).
    - `openai_s2s_num_threads` (int, default: 40): Max concurrent requests for OpenAI Speech-to-Speech.
    - `openai_base_url` (str, optional): Custom base URL for OpenAI-compatible APIs.
    - `gpt4o_s2s_rpm_cap` (int, default: 10): RPM cap for the GPT-4o S2S model.
    - `gemini_num_threads` (int, default: 120): Max concurrent requests to Gemini.
    - `gemini_recitation_rate_check_volume` (int, default: 100): Number of prompts after which the Gemini recitation rate is checked.
    - `gemini_recitation_rate_threshold` (float, default: 0.5): If the Gemini recitation rate exceeds this, the run may pause/error.
    - `gray_swan_num_threads` (int, default: 80): Max concurrent requests to GraySwan.
    - `huggingface_num_threads` (int, default: 100): Max concurrent requests to HuggingFace Inference Endpoints.
    - `together_num_threads` (int, default: 80): Max concurrent requests to Together AI.
    - `openrouter_num_threads` (int, default: 80): Max concurrent requests to OpenRouter.
    - `vllm_num_threads` (int, default: 8): Max concurrent requests for local VLLM.
    - `deepseek_num_threads` (int, default: 20): Max concurrent requests for DeepSeek.
    - `prompt_history_dir` (Path | "default" | None, default: "default"): Directory for prompt history. "default" uses `PROMPT_HISTORY_DIR` from the environment or `.prompt_history` in the repo root. `None` disables it.
    - `cache_dir` (Path | "default" | None, default: "default"): Directory for the file-based cache. "default" uses `CACHE_DIR` from the environment or `.cache` in the repo root. `None` disables the file cache (Redis might still be used if `use_redis` is true).
    - `use_redis` (bool, default: False): Enable Redis caching. Overridden by `REDIS_CACHE=true` in the environment.
    - `empty_completion_threshold` (int, default: 0): Threshold for empty completions (relevant for Gemini).
    - `use_gpu_models` (bool, default: False): Whether to initialize models that require a GPU (currently for `BatchModel` if implemented).
    - `anthropic_api_key` (str, optional): Directly pass an Anthropic API key.
    - `openai_api_key` (str, optional): Directly pass an OpenAI API key.
    - `print_prompt_and_response` (bool, default: False): Print prompts/responses to console. Overrides the `SAFETYTOOLING_PRINT_PROMPTS` env var.
    - `use_vllm_if_model_not_found` (bool, default: False): If a model ID is not recognized, attempt to use VLLM.
    - `vllm_base_url` (str, default: "http://localhost:8000/v1/chat/completions"): Base URL for the VLLM server.
    - `no_cache` (bool, default: False): Disable all caching. Overridden by `NO_CACHE=true` in the environment.
    - `oai_embedding_batch_size` (int, default: 2048): Batch size for OpenAI embedding requests.
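The various `*_num_threads` settings cap how many requests per provider are in flight at once. Conceptually this is a semaphore wrapped around each call; the sketch below illustrates the pattern with a stand-in coroutine (illustrative only, not the library's internals):

```python
import asyncio

async def bounded_gather(args, worker, max_concurrency: int):
    """Run `worker(arg)` over all args with at most `max_concurrency` in flight."""
    sem = asyncio.Semaphore(max_concurrency)

    async def run_one(arg):
        async with sem:  # blocks when max_concurrency tasks are already running
            return await worker(arg)

    # gather preserves input order regardless of completion order
    return await asyncio.gather(*(run_one(a) for a in args))

async def fake_api_call(i: int) -> int:
    await asyncio.sleep(0)  # stand-in for a network round trip
    return i * 2

results = asyncio.run(bounded_gather(range(5), fake_api_call, max_concurrency=2))
print(results)  # [0, 2, 4, 6, 8]
```

Lowering a provider's thread count trades throughput for fewer rate-limit errors, which is why the defaults differ per provider.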
  - Key public methods:
    - `async def __call__(...)` (making LLM calls): This is the primary method for generating completions.

      ```python
      async def __call__(
          self,
          model_id: str,
          prompt: Prompt | BatchPrompt,
          audio_out_dir: str | Path = None,
          print_prompt_and_response: bool = False,
          n: int = 1,
          max_attempts_per_api_call: int = 10,
          num_candidates_per_completion: int = 1,
          is_valid: Callable[[str], bool] = lambda _: True,
          insufficient_valids_behaviour: Literal["error", "continue", "pad_invalids", "retry"] = "retry",
          gemini_use_vertexai: bool = False,
          huggingface_model_url: str | None = None,
          force_provider: str | None = None,
          use_cache: bool = True,
          **kwargs,  # Additional provider-specific parameters (e.g., temperature, max_tokens)
      ) -> list[LLMResponse]:
      ```
      - Parameters:
        - `model_id` (str): Identifier of the model to use (e.g., "gpt-4o-mini", "claude-3-haiku-20240307").
        - `prompt` (Prompt | BatchPrompt): A `Prompt` object (or `BatchPrompt` for `BatchModel`) containing the messages.
        - `audio_out_dir` (str | Path, optional): Directory to save audio output for Speech-to-Speech models.
        - `print_prompt_and_response` (bool, default: False): Overrides the instance-level setting for this call.
        - `n` (int, default: 1): Number of completions to generate and return.
        - `max_attempts_per_api_call` (int, default: 10): Retries if an API call fails.
        - `num_candidates_per_completion` (int, default: 1): How many raw completions to generate per desired `n`. E.g., if `n=1` and `num_candidates_per_completion=5`, 5 raw completions are generated.
        - `is_valid` (Callable[[str], bool], default: `lambda _: True`): A function to validate a completion string. Invalid responses are filtered out.
        - `insufficient_valids_behaviour` (Literal["error", "continue", "pad_invalids", "retry"], default: "retry"): Behavior if not enough valid completions are found:
          - "error": Raise a RuntimeError.
          - "continue": Return the valid responses, even if fewer than `n`.
          - "pad_invalids": Pad with invalid responses up to `n`.
          - "retry": Retry API calls until `n` valid responses are found (respecting `max_attempts_per_api_call`).
        - `gemini_use_vertexai` (bool, default: False): For Gemini models, use the Vertex AI SDK instead of the GenAI SDK.
        - `huggingface_model_url` (str, optional): URL for a HuggingFace Inference Endpoint if `model_id` is not in the predefined list.
        - `force_provider` (str, optional): Force use of a specific provider (e.g., "openai", "anthropic"). Useful for new models not yet in predefined lists.
        - `use_cache` (bool, default: True): Whether to use caching for this specific call. Overridden by instance-level `no_cache=True`.
        - `**kwargs`: Provider-specific parameters like `temperature`, `max_tokens`, `top_p`, `logprobs`, `seed`, `stop_sequences` (for Anthropic; use `stop` for OpenAI), `response_format` (for OpenAI).
      - Returns: `list[LLMResponse]`: A list of `LLMResponse` objects.
      - Example:

        ```python
        from safetytooling.data_models import ChatMessage, MessageRole, Prompt

        prompt_obj = Prompt(messages=[ChatMessage(content="What is your name?", role=MessageRole.user)])
        responses = await api(
            model_id="gpt-4o-mini",
            prompt=prompt_obj,
            temperature=0.7,
            max_tokens=50,
            n=1,
        )
        if responses:
            print(responses[0].completion)
        ```
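The `is_valid` filtering with `insufficient_valids_behaviour="retry"` amounts to a generate-filter-retry loop. A minimal self-contained sketch of that logic, using a stand-in `generate` callable instead of a real API call (names are illustrative, not safety-tooling internals):

```python
from typing import Callable

def sample_valid(
    generate: Callable[[], str],
    is_valid: Callable[[str], bool],
    n: int = 1,
    max_attempts: int = 10,
) -> list[str]:
    """Keep generating until `n` valid completions are collected,
    mirroring insufficient_valids_behaviour="retry"."""
    valids: list[str] = []
    for _ in range(max_attempts):
        candidate = generate()
        if is_valid(candidate):
            valids.append(candidate)
        if len(valids) >= n:
            return valids
    raise RuntimeError(f"Only {len(valids)} valid completions after {max_attempts} attempts")

# Deterministic demo: require completions longer than 10 characters
outs = iter(["too short", "a sufficiently long completion", "also long enough"])
valid = sample_valid(lambda: next(outs), lambda s: len(s) > 10, n=2, max_attempts=3)
```

With "continue" you would return `valids` instead of raising, and with "pad_invalids" you would top up the list with the rejected candidates.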
    - `async def ask_single_question(...)`: A convenience wrapper around `__call__` for simple user questions.

      ```python
      async def ask_single_question(
          self,
          model_id: str,
          question: str,
          system_prompt: str | None = None,
          **api_kwargs,  # Passed to __call__
      ) -> list[str]:
      ```

      - Parameters:
        - `model_id` (str): Model identifier.
        - `question` (str): The user's question.
        - `system_prompt` (str, optional): An optional system message.
        - `**api_kwargs`: Additional arguments for the main `__call__` method (e.g., `n`, `temperature`).
      - Returns: `list[str]`: A list of completion strings.
      - Example:

        ```python
        answers = await api.ask_single_question(
            model_id="claude-3-opus-20240229",
            question="What is the capital of France?",
            n=1,
        )
        print(answers[0])
        ```
    - `async def moderate(...)` (OpenAI Moderation): Checks text against OpenAI's moderation API.

      ```python
      async def moderate(
          self,
          texts: list[str],
          model_id: str = "text-moderation-latest",
          progress_bar: bool = False,
      ) -> list[TaggedModeration] | None:
      ```

      - Parameters:
        - `texts` (list[str]): A list of text strings to moderate.
        - `model_id` (str, default: "text-moderation-latest"): Moderation model to use.
        - `progress_bar` (bool, default: False): Show a progress bar for batch processing.
      - Returns: `list[TaggedModeration] | None`: A list of `TaggedModeration` objects, or `None` if both caching and the API call fail. Each object contains the moderation results for the corresponding input text.
      - Example:

        ```python
        moderation_results = await api.moderate(["This is a test.", "This is another test."])
        if moderation_results:
            for result in moderation_results:
                print(f"Flagged: {result.moderation.flagged}")
        ```
    - `async def embed(...)` (OpenAI Embeddings): Generates embeddings for a list of texts using OpenAI.

      ```python
      async def embed(
          self,
          texts: list[str],
          model_id: str = "text-embedding-3-large",
          dimensions: int | None = None,
          progress_bar: bool = False,
      ) -> np.ndarray:
      ```

      - Parameters:
        - `texts` (list[str]): List of texts to embed.
        - `model_id` (str, default: "text-embedding-3-large"): Embedding model to use. Supported: "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002".
        - `dimensions` (int, optional): Desired dimensionality; supported by the newer embedding models ("text-embedding-3-small", "text-embedding-3-large").
        - `progress_bar` (bool, default: False): Show a progress bar.
      - Returns: `np.ndarray`: A NumPy array where each row is the embedding for the corresponding text.
      - Example:

        ```python
        embeddings_array = await api.embed(
            ["Hello world", "Another sentence"],
            model_id="text-embedding-3-small",
            dimensions=256,
        )
        print(embeddings_array.shape)
        ```
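Once you have the embedding matrix, downstream similarity computations are plain NumPy. A small self-contained sketch (the `embeddings` array below is synthetic, standing in for the result of `api.embed`):

```python
import numpy as np

def cosine_similarity_matrix(embeddings: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarities between rows of an embedding matrix."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / norms  # normalize each row to unit length
    return unit @ unit.T

# Synthetic stand-in for api.embed output: 3 texts, 4 dimensions
embeddings = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
])
sims = cosine_similarity_matrix(embeddings)
# Identical rows have similarity 1.0; orthogonal rows 0.0
```

The same matrix can feed clustering or nearest-neighbor lookups directly.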
- Caching:
  - File-based: By default, API responses are cached in the `cache_dir` (default `~/.cache` or `repo_root/.cache`). Cache files are organized by a hash of the model parameters and then binned by a hash of the prompt.
  - Redis: If `use_redis=True` (or `REDIS_CACHE=true` in `.env`), Redis will be used. Configure the Redis connection via environment variables (`REDIS_PASSWORD`; defaults to `localhost:6379`).
  - Disabling the cache:
    - Set `NO_CACHE=true` in your `.env` file or as an environment variable.
    - Initialize `InferenceAPI(cache_dir=None, no_cache=True)`.
    - Pass `use_cache=False` to an individual `api()` call.
  - Cache keys are generated based on the `Prompt` object and `LLMParams` (model ID, temperature, max_tokens, etc.).
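The key derivation can be pictured roughly as hashing the serialized prompt together with the sampling parameters, so any change to either produces a fresh cache entry. The sketch below illustrates the idea only; it is not safety-tooling's actual implementation:

```python
import hashlib
import json

def cache_key(prompt_messages: list[dict], params: dict) -> str:
    """Illustrative cache key: SHA-256 of canonical JSON of prompt + params."""
    payload = json.dumps(
        {"prompt": prompt_messages, "params": params},
        sort_keys=True,  # canonical ordering makes the hash deterministic
    )
    return hashlib.sha256(payload.encode()).hexdigest()

key = cache_key(
    [{"role": "user", "content": "What is your name?"}],
    {"model_id": "gpt-4o-mini", "temperature": 0.7, "max_tokens": 50},
)
```

This is why, e.g., changing `temperature` causes a cache miss even for an identical prompt.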
- Prompt logging: If `prompt_history_dir` is enabled (default `~/.prompt_history` or `repo_root/.prompt_history`), human-readable `.txt` files containing the prompt and response(s) are saved, timestamped for easy reference. This can be disabled by setting `prompt_history_dir=None` during `InferenceAPI` initialization.
- Rate limits & concurrency: The API manages rate limits, especially for OpenAI, by respecting headers and distributing requests. Concurrency for each provider can be set during `InferenceAPI` initialization (e.g., `openai_num_threads`, `anthropic_num_threads`).
- Using OpenAI-compatible APIs (e.g., OpenRouter, DeepSeek): You can use `InferenceAPI` with any LLM provider that offers an OpenAI-compatible API endpoint.

  ```python
  import os

  # Example for OpenRouter
  openrouter_api_key = os.getenv("OPENROUTER_API_KEY")  # Loaded by utils.setup_environment()
  api_openrouter = InferenceAPI(
      openai_base_url="https://openrouter.ai/api/v1",
      openai_api_key=openrouter_api_key,  # Pass the specific key for this endpoint
  )
  response = await api_openrouter(
      model_id="mistralai/mistral-7b-instruct",  # OpenRouter model string
      prompt=Prompt(messages=[ChatMessage(role="user", content="Tell me a joke about AI.")]),
      force_provider="openai",  # Crucial: tells InferenceAPI to use the OpenAI client logic
  )
  print(response[0].completion)

  # Example for DeepSeek (already has dedicated handling but illustrates the principle)
  # deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
  # api_deepseek = InferenceAPI(
  #     openai_base_url="https://api.deepseek.com",
  #     openai_api_key=deepseek_api_key,
  # )
  # response = await api_deepseek(
  #     model_id="deepseek-chat",
  #     prompt=Prompt(messages=[ChatMessage(role="user", content="What is Deep Learning?")]),
  #     force_provider="openai",
  # )
  # print(response[0].completion)
  ```

  When `force_provider="openai"` is used, `InferenceAPI` uses its OpenAI client logic but directs requests to the `openai_base_url` you provided. `model_id` should be the specific model string recognized by that custom endpoint.
- VLLM integration (local model deployment): `safety-tooling` provides utilities to deploy models locally using VLLM and query them via `InferenceAPI`.

  ```python
  from safetytooling.utils.vllm_utils import deploy_model_vllm_locally_auto
  from safetytooling.data_models import ChatMessage, MessageRole, Prompt

  # This will download the model (if not cached) and start a VLLM server process.
  # Ensure you have vllm installed: uv pip install vllm
  # The function returns a VLLMDeployment object which holds the server process.
  # It's important to manage this object properly, e.g., by calling server.close()
  # when done. deploy_model_vllm_locally_auto also registers cleanup handlers
  # to try to terminate the server on script exit.
  try:
      server = await deploy_model_vllm_locally_auto(
          "meta-llama/Llama-3.1-8B-Instruct",  # Or any HuggingFace model ID compatible with VLLM
          max_model_len=1024,
          max_num_seqs=32,
      )

      # Initialize InferenceAPI to use the local VLLM server
      vllm_api = InferenceAPI(
          vllm_base_url=f"{server.base_url}/v1/chat/completions",  # Standard VLLM OpenAI-compatible endpoint
          vllm_num_threads=32,
          use_vllm_if_model_not_found=True,  # Good for when model_id is the HF path
      )

      prompt_vllm = Prompt(messages=[ChatMessage(content="What is your name?", role=MessageRole.user)])
      response_vllm = await vllm_api(
          model_id=server.model_name,  # Use the model name VLLM is serving
          prompt=prompt_vllm,
          print_prompt_and_response=True,
          temperature=0.1,
      )
      print(f"VLLM Response: {response_vllm[0].completion}")
  finally:
      if "server" in locals() and server:
          print("Closing VLLM server...")
          server.close()
  ```

  The `deploy_model_vllm_locally_auto` function handles determining whether the model is a base model or has LoRA adapters. For more control, use `deploy_model_vllm_locally` and specify `base_model` and `lora_adapters` manually.
- Description: The `BatchInferenceAPI` class (`safetytooling.apis.BatchInferenceAPI`) provides a wrapper around the batch processing capabilities of Anthropic and OpenAI. This is useful for processing a large number of prompts more efficiently, and often at a lower cost, than individual API calls, especially when latency per prompt matters less than overall throughput.
- Usage (Library):
  - Import & Initialization:

    ```python
    from pathlib import Path

    from safetytooling.apis.batch_api import BatchInferenceAPI  # Note: different import path
    from safetytooling.utils import utils

    # Load API keys from .env.
    # For Anthropic batch, ensure ANTHROPIC_BATCH_API_KEY is set if different,
    # or pass anthropic_tag="ANTHROPIC_BATCH_API_KEY" to setup_environment.
    utils.setup_environment()

    # Initialize the BatchInferenceAPI
    batch_api = BatchInferenceAPI(
        log_dir=Path(".my_batch_logs"),  # Directory to store input files for OpenAI batch
        cache_dir=Path(".my_cache_dir"),  # Uses the same caching mechanism
        # openai_api_key="sk-...",
        # anthropic_api_key="sk-ant-yourAnthropicBatchKey",  # Or pass the specific batch key
        no_cache=False,
    )
    ```

    `BatchInferenceAPI` constructor parameters:
    - `log_dir` (Path | "default", default: "default"): Directory to store temporary files, particularly input files for the OpenAI batch API. If "default", it uses `prompt_history_dir` (from the environment or `~/.prompt_history`).
    - `prompt_history_dir` (Path | "default" | None, default: "default"): See `InferenceAPI`.
    - `cache_dir` (Path | "default" | None, default: "default"): See `InferenceAPI`.
    - `use_redis` (bool, default: False): See `InferenceAPI`.
    - `anthropic_api_key` (str, optional): Directly pass an Anthropic API key (ideally the batch-specific one).
    - `openai_api_key` (str, optional): Directly pass an OpenAI API key.
    - `no_cache` (bool, default: False): Disable all caching.
  - Key public methods:
    - `async def __call__(...)` (making batch LLM calls): Submits a list of prompts for batch processing.

      ```python
      async def __call__(
          self,
          model_id: str,
          prompts: list[Prompt],
          log_dir: Path | None = None,  # Overrides instance-level log_dir for this call
          use_cache: bool = True,
          max_tokens: int | None = None,
          chunk: int | None = None,  # Max prompts per batch request (provider-dependent limits apply)
          **kwargs,  # Additional provider-specific parameters (e.g., temperature)
      ) -> tuple[list[LLMResponse | None], str]:
      ```

      - Parameters:
        - `model_id` (str): Identifier of the model (must be a supported Anthropic or OpenAI model for batch).
        - `prompts` (list[Prompt]): A list of `Prompt` objects.
        - `log_dir` (Path, optional): Directory to store intermediate files for this batch job (e.g., the OpenAI input file).
        - `use_cache` (bool, default: True): Whether to use caching.
        - `max_tokens` (int, optional): Max tokens per completion. Required for Anthropic batch.
        - `chunk` (int, optional): Maximum number of prompts to send in a single underlying batch request.
          - Anthropic: Default 100,000 prompts or ~250MB total input size.
          - OpenAI: Default 50,000 prompts per file.

          The `BatchInferenceAPI` will further chunk based on these limits if your input `prompts` list is larger.
        - `**kwargs`: Provider-specific parameters (e.g., `temperature`). `seed` is used for caching.
      - Returns: `tuple[list[LLMResponse | None], str]`:
        - A list of `LLMResponse` objects, one for each input prompt. An element can be `None` if a particular prompt failed within the batch.
        - A string representing the batch ID(s). If chunking occurred, this will be a comma-separated string of multiple batch IDs. If all responses were cached, it will be "cached".
      - Example (Anthropic batch):

        ```python
        from safetytooling.data_models import ChatMessage, MessageRole, Prompt

        # In an async function:
        prompts_list = [
            Prompt(messages=[ChatMessage(content="Tell me about dogs.", role=MessageRole.user)]),
            Prompt(messages=[ChatMessage(content="Tell me about cats.", role=MessageRole.user)]),
        ]
        responses, batch_id = await batch_api(
            model_id="claude-3-haiku-20240307",  # Or other supported Claude model
            prompts=prompts_list,
            max_tokens=100,  # Required for Anthropic
            temperature=0.5,
        )
        print(f"Anthropic Batch ID: {batch_id}")
        for resp in responses:
            if resp:
                print(resp.completion)
        ```

      - Example (OpenAI batch):

        ```python
        from pathlib import Path

        from safetytooling.data_models import ChatMessage, MessageRole, Prompt

        # In an async function:
        prompts_list = [
            Prompt(messages=[ChatMessage(content="What is Python?", role=MessageRole.user)]),
            Prompt(messages=[ChatMessage(content="What is JavaScript?", role=MessageRole.user)]),
        ]

        # For OpenAI, log_dir is where the input .jsonl file will be created.
        # Ensure this directory exists or is writable.
        openai_batch_log_dir = Path("./openai_batch_temp_files")
        openai_batch_log_dir.mkdir(exist_ok=True)

        responses, batch_id = await batch_api(
            model_id="gpt-4o-mini",  # Or other supported OpenAI model
            prompts=prompts_list,
            log_dir=openai_batch_log_dir,
            max_tokens=150,
            temperature=0.6,
        )
        print(f"OpenAI Batch ID: {batch_id}")
        for resp in responses:
            if resp:
                print(resp.completion)
        ```
- Caching: The `BatchInferenceAPI` uses the same caching mechanism as `InferenceAPI`. Individual prompts within a batch are checked against the cache. If all prompts in a call are cached, "cached" is returned as the batch ID. If some are cached and others are not, only the uncached prompts are sent to the batch provider.
- Anthropic batch specifics:
  - The `max_tokens` parameter is required for Anthropic batch calls.
  - The Anthropic batch API has limits on the number of requests (default 100,000) and total input file size (around 250MB). The `chunk_prompts_for_anthropic` utility function (used internally) splits prompts accordingly.
  - See `examples/anthropic_batch_api/batch_api.ipynb` and `run_anthropic_batch.py` for detailed examples.
- OpenAI batch specifics:
  - The API prepares a JSONL file with your prompts, uploads it to OpenAI, creates a batch job, polls for completion, and then retrieves results.
  - The `log_dir` parameter is important as it's where the input JSONL file is temporarily stored.
  - The OpenAI batch API has a limit of 50,000 prompts per file.
  - See `examples/open_ai_batch_api/batch_api.ipynb` for a detailed example.
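The chunking behavior described above amounts to splitting the prompt list into provider-sized pieces before submission, then stitching the per-chunk batch IDs together. A minimal illustration of the splitting step (a generic helper, not the library's internal `chunk_prompts_for_anthropic`):

```python
def chunk_list(items: list, chunk_size: int) -> list[list]:
    """Split `items` into consecutive chunks of at most `chunk_size` each."""
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]

# 120,000 prompts against Anthropic's default 100,000-per-batch limit
# would be submitted as two batches: 100,000 prompts and 20,000 prompts.
chunks = chunk_list(list(range(120_000)), 100_000)
```

Each chunk becomes one underlying batch job, which is why a chunked call returns a comma-separated string of batch IDs.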
-
-
Description: The
safetytooling.apis.finetuning.openai.runmodule allows you to launch OpenAI finetuning jobs. It handles data validation, file uploads to OpenAI, job creation, and syncing metadata and results with Weights & Biases. -
Usage (Script): Run the script from your terminal:
python -m safetytooling.apis.finetuning.openai.run \ --model 'gpt-3.5-turbo-1106' \ --train_file path/to/your/train_data.jsonl \ --val_file path/to/your/validation_data.jsonl \ --n_epochs 1 \ --learning_rate_multiplier 2.0 \ --batch_size "auto" \ --wandb_project_name "my-finetuning-project" \ --wandb_entity "my-wandb-username" \ --tags "experiment1" "dataset_v2" \ --save_folder "output_finetunes/" \ --save_config TrueCommand-line Arguments (
OpenAIFTConfig):--train_file(Path, required): Path to the training data JSONL file.--model(str, required): The base model to fine-tune (e.g., 'gpt-3.5-turbo-1106', 'babbage-002', 'gpt-4o-mini-2024-07-18').--val_file(Path, optional): Path to the validation data JSONL file.--n_epochs(int, default: 1): Number of training epochs.--learning_rate_multiplier(float | "auto", default: "auto"): Multiplier for the default learning rate.--batch_size(int | "auto", default: "auto"): Batch size. "auto" lets OpenAI decide.--beta(float | "auto", default: "auto"): Beta hyperparameter for DPO fine-tuning.--method(Literal["supervised", "dpo"], default: "supervised"): Fine-tuning method.--dry_run(bool, default: False): If set, validates data and estimates cost without launching the job.--logging_level(str, default: "info"): Logging level for the script.--openai_tag(str, default: "OPENAI_API_KEY1"): Which OpenAI API key from.envto use (e.g.,OPENAI_API_KEY,OPENAI_API_KEY1).--wandb_project_name(str, optional): Weights & Biases project name.--wandb_entity(str, optional): Weights & Biases entity (username or team).--tags(tuple[str, ...], default: ("test",)): Tags for the WandB run.--save_config(bool, default: False): If true, saves the final fine-tuning job configuration tosave_folder.--save_folder(str, optional): Directory to save job configuration. If it ends with/, appendsmodel_name/train_file_name[id]fine_tuned_model_id.
- Data Format: Training and validation files must be in JSONL format, where each line is a JSON object representing a single training example. For chat models, this typically looks like:

```json
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Who won the world series in 2020?"}, {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}]}
```

For DPO, the format is different (see OpenAI documentation). The script `safetytooling.apis.finetuning.openai.check.openai_check_finetuning_data` validates the data and estimates costs.
- Deleting OpenAI Files: The notebook `safetytooling/apis/finetuning/openai/delete_files.ipynb` contains utility functions to list and delete files uploaded to your OpenAI account, which can be useful for managing storage used by fine-tuning datasets and results.
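As a rough illustration of the kind of checks a data validator performs on the chat JSONL format above, here is a minimal stdlib sketch (a simplified illustration, not the actual `openai_check_finetuning_data` implementation, which also estimates token counts and cost):

```python
import json
from pathlib import Path

VALID_ROLES = {"system", "user", "assistant"}

def validate_chat_jsonl(path: Path) -> list[str]:
    """Return a list of human-readable problems found in a chat-format JSONL file."""
    problems = []
    for i, line in enumerate(path.read_text().splitlines(), start=1):
        try:
            example = json.loads(line)
        except json.JSONDecodeError:
            problems.append(f"line {i}: not valid JSON")
            continue
        messages = example.get("messages")
        if not isinstance(messages, list) or not messages:
            problems.append(f"line {i}: missing non-empty 'messages' list")
            continue
        for msg in messages:
            if msg.get("role") not in VALID_ROLES:
                problems.append(f"line {i}: invalid role {msg.get('role')!r}")
            if not isinstance(msg.get("content"), str):
                problems.append(f"line {i}: 'content' must be a string")
    return problems
```

Running this before launching a job catches malformed lines early, which is cheaper than having the upload rejected by the API.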
-
Description: The `safetytooling.apis.finetuning.together.run` module allows you to launch finetuning jobs on Together AI. It handles file uploads, job creation, and can sync metadata with Weights & Biases. It also supports saving the trained model weights.
- Usage (Script): Run the script from your terminal:
```bash
python -m safetytooling.apis.finetuning.together.run \
    --model 'meta-llama/Llama-3-8B-Instruct-Turbo' \
    --train_file path/to/your/train_data.jsonl \
    --n_epochs 3 \
    --learning_rate 1e-5 \
    --lora True \
    --lora_r 16 \
    --suffix "my-custom-model" \
    --wandb_project_name "my-together-finetuning" \
    --save_folder "output_together_finetunes/" \
    --save_model True \
    --save_config True
```

Command-line Arguments (`TogetherFTConfig`):
- `--train_file` (Path, required): Path to the training data JSONL file.
- `--model` (str, required): The base model on Together AI to fine-tune (e.g., 'meta-llama/Llama-3-8B-Instruct-Turbo').
- `--val_file` (Path, optional): Path to the validation data JSONL file.
- `--n_epochs` (int, default: 1): Number of training epochs.
- `--batch_size` (int | "auto", default: 16): Batch size.
- `--learning_rate` (float, default: 1e-5): Learning rate.
- `--lora` (bool, default: False): Whether to use LoRA for fine-tuning.
- `--lora_r` (int, default: 8): LoRA rank.
- `--lora_alpha` (int, default: 8): LoRA alpha.
- `--lora_dropout` (float, default: 0.0): LoRA dropout.
- `--suffix` (str, default: ""): Suffix to append to the fine-tuned model name on Together AI.
- `--dry_run` (bool, default: False): Note: `dry_run` comes from the `FinetuneConfig` base class, but since cost estimation is OpenAI-specific, its effect on Together AI may vary or not be fully implemented.
- `--logging_level` (str, default: "info"): Logging level.
- `--openai_tag`: Not directly used by Together AI; inherited from the base config.
- `--wandb_project_name` (str, optional): Weights & Biases project name.
- `--wandb_entity` (str, optional): Weights & Biases entity.
- `--tags` (tuple[str, ...], default: ("test",)): Tags for the WandB run.
- `--save_config` (bool, default: False): Save job config.
- `--save_model` (bool, default: False): Download and save the fine-tuned model weights.
- `--save_folder` (str, optional): Directory to save job config and model weights.
Data Format: Follow Together AI's required JSONL format for fine-tuning (typically similar to OpenAI's chat format).
Hugging Face Hub Handler Example: The file `safetytooling/apis/finetuning/together/huggingface_hub_example/handler.md` provides an example `handler.py` script. This script can be used when deploying a LoRA model (fine-tuned via Together AI, with weights downloaded) as a Hugging Face Inference Endpoint. It shows how to load a base model and apply the LoRA adapter.
These scripts help monitor your API consumption.
-
OpenAI API Usage (`safetytooling.apis.inference.usage.usage_openai`)
- Description: Fetches and displays the current rate limit usage for specified OpenAI models and API keys.
- Command:

```bash
python -m safetytooling.apis.inference.usage.usage_openai \
    --models 'gpt-4o-mini' 'gpt-3.5-turbo-0125' \
    --openai_tags 'OPENAI_API_KEY1' 'OPENAI_API_KEY2'
```

- Arguments:
  - `--models` (list[str]): A list of model IDs to check usage for. Default: `['gpt-3.5-turbo-0125', 'gpt-3.5-turbo-1106', 'gpt-4-1106-preview', 'o1', 'o1-2024-12-17', 'o3-mini', 'o3-mini-2025-01-31']`
  - `--openai_tags` (list[str]): A list of `.env` variable names for the OpenAI API keys to check. Default: `['OPENAI_API_KEY1', 'OPENAI_API_KEY2']`
- Output: Prints usage (0.0 to 1.0, where 1.0 means hitting rate limits) for token and request limits. -1 indicates an error fetching data.
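The 0.0–1.0 figure is simply consumed capacity over total capacity. As a hedged illustration of the arithmetic (the function and field names here are hypothetical, not the script's actual implementation):

```python
def usage_fraction(limit: int, remaining: int) -> float:
    """Fraction of a rate limit consumed: 0.0 = idle, 1.0 = at the limit.

    Returns -1.0 when the data is unavailable, mirroring the script's
    convention of printing -1 on fetch errors.
    """
    if limit <= 0:
        return -1.0
    return (limit - remaining) / limit

# e.g. 60k of an 80k tokens-per-minute budget still remaining:
print(usage_fraction(limit=80_000, remaining=60_000))  # 0.25
```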
-
Anthropic API Usage (`safetytooling.apis.inference.usage.usage_anthropic`)
- Description: Estimates how many more concurrent requests the Anthropic API can handle for your key before hitting rate limits.
- Command:

```bash
python -m safetytooling.apis.inference.usage.usage_anthropic
```

- Arguments: None (uses `ANTHROPIC_API_KEY` from the environment).
- Output: Prints an estimate of available concurrent request capacity.
-
Description: The `safetytooling.apis.tts.elevenlabs` module provides functions to convert text to speech using the ElevenLabs API.
- Usage (Library):
-
Import:

```python
from safetytooling.apis.tts.elevenlabs import generate_tts_audio_from_text, generate_tts_audio_from_dataframe
from pathlib import Path
import pandas as pd
import asyncio  # For running async functions
```
-
`async def generate_tts_audio_from_text(...)`: Generates TTS for a single string.

```python
async def generate_tts_audio_from_text(
    text: str,
    output_dir: Union[str, Path],
    voice: str = "Rachel",  # Default voice name
    model: str = "eleven_multilingual_v2",
    client: AsyncElevenLabs = None,  # Optional pre-initialized client
    convert_to_wav: bool = False,  # If True, converts MP3 output to WAV
    name_override: Optional[str] = None,  # Custom filename (without extension)
) -> str:  # Returns path to the audio file
```

- Example:

```python
# In an async context:
output_path = await generate_tts_audio_from_text(
    text="Hello, this is a test of the ElevenLabs API.",
    output_dir=Path("./tts_output"),
    voice="Rachel",  # Or any other valid voice name/ID
    convert_to_wav=True,
)
print(f"Audio saved to: {output_path}")
```

The `voice` parameter can be a voice name (e.g., "Rachel", "Adam") or a voice ID from ElevenLabs. The script uses internal mappings for common names, but direct IDs also work. The default model is `eleven_multilingual_v2`.
-
`async def generate_tts_audio_from_dataframe(...)`: Generates TTS for texts in a Pandas DataFrame column.

```python
async def generate_tts_audio_from_dataframe(
    df: pd.DataFrame,
    output_dir: Union[str, Path],
    transcription_col: str = "behavior",  # Column in df containing text
    voice: str = "Rachel",
    model: str = "eleven_multilingual_v2",
    convert_to_wav: bool = False,
) -> List[str]:  # Returns list of paths to generated audio files
```

- Example:

```python
# Assuming df is a Pandas DataFrame with a column "my_texts"
# In an async context:
data = {"my_texts": ["First sentence.", "Second sentence for TTS."]}
df = pd.DataFrame(data)

audio_files = await generate_tts_audio_from_dataframe(
    df=df,
    output_dir=Path("./tts_output_df"),
    transcription_col="my_texts",
    convert_to_wav=True,
)
df["audio_file_paths"] = audio_files
print(df)
```
-
The `safetytooling.data_models` package defines Pydantic models for consistent data structures throughout the library.
-
`Prompt` (`safetytooling.data_models.messages.Prompt`)
- Represents an LLM prompt.
- Contains a sequence of `ChatMessage` objects.
- Example:

```python
Prompt(messages=[ChatMessage(role=MessageRole.user, content="Hello!")])
```
-
`ChatMessage` (`safetytooling.data_models.messages.ChatMessage`)
- Represents a single message in a conversation.
- Attributes:
  - `role` (MessageRole): The role of the message sender (e.g., `user`, `system`, `assistant`, `audio`, `image`).
  - `content` (str | Path): The content of the message. Can be text or a path to an image/audio file.
- The `MessageRole` enum defines valid roles.
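For intuition, the roles listed above could be modeled with a plain Python enum like this (a sketch based on the roles named here, not the library's actual definition):

```python
from enum import Enum

class MessageRole(str, Enum):
    """Valid senders for a chat message (sketch; mirrors the roles listed above)."""
    user = "user"
    system = "system"
    assistant = "assistant"
    audio = "audio"
    image = "image"

# Inheriting from str lets roles serialize naturally, e.g. in JSON payloads:
print(MessageRole.user.value)  # "user"
```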
-
`LLMResponse` (`safetytooling.data_models.inference.LLMResponse`)
- Represents the output from an LLM call.
- Key Attributes:
  - `model_id` (str): Identifier of the model that generated the response.
  - `completion` (str): The generated text.
  - `stop_reason` (StopReason | GeminiStopReason | GeminiBlockReason): Reason why the model stopped generating.
  - `cost` (float): Estimated cost of the API call (primarily for OpenAI).
  - `audio_out` (str | Path, optional): Path to the output audio file for S2S models.
  - `logprobs` (list[dict[str, float]], optional): Log probabilities for tokens, if requested and supported.
  - `safety_ratings` (dict | SafetyRatings, optional): Safety ratings from Gemini models.
-
`LLMParams` (`safetytooling.data_models.inference.LLMParams`)
- Represents parameters for an LLM call, used for caching.
- Includes `model`, `temperature`, `max_tokens`, etc.
-
`EmbeddingParams` & `EmbeddingResponseBase64` (`safetytooling.data_models.embedding`)
- Used for embedding generation requests and responses.
-
`HashableBaseModel` (`safetytooling.data_models.hashable.HashableBaseModel`)
- A base Pydantic model that provides a `model_hash()` method, used for generating cache keys. Most data models inherit from this.
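The idea behind `model_hash()` can be sketched with the standard library alone: serialize the object deterministically, then hash the result (a simplified illustration using a hypothetical `Params` class, not the library's implementation):

```python
import dataclasses
import hashlib
import json

@dataclasses.dataclass(frozen=True)
class Params:
    """Stand-in for a hashable params model (hypothetical example class)."""
    model: str
    temperature: float
    max_tokens: int

    def model_hash(self) -> str:
        # Deterministic serialization (sorted keys) -> stable hash -> cache key.
        payload = json.dumps(dataclasses.asdict(self), sort_keys=True)
        return hashlib.sha256(payload.encode()).hexdigest()

a = Params(model="gpt-4o-mini", temperature=0.0, max_tokens=100)
b = Params(model="gpt-4o-mini", temperature=0.0, max_tokens=100)
assert a.model_hash() == b.model_hash()  # identical params -> identical cache key
```

Because identical parameter sets hash to the same key, a repeated API call with the same prompt and params can be served from the cache instead of hitting the provider again.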
A collection of helper modules and functions.
-
`utils.setup_environment()` (`safetytooling.utils.utils.setup_environment`)
- Description: Loads API keys from the `.env` file into environment variables and configures logging. Crucial to call at the start of scripts.
- Parameters:
  - `logging_level` (str, default: "info"): Sets the root logging level (e.g., "debug", "info", "warning").
  - `openai_tag` (str, default: "OPENAI_API_KEY"): The variable name in `.env` to load as the primary `OPENAI_API_KEY`.
  - `anthropic_tag` (str, default: "ANTHROPIC_API_KEY"): The variable name in `.env` to load as the primary `ANTHROPIC_API_KEY`.
- Usage:

```python
from safetytooling.utils import utils

utils.setup_environment(openai_tag="OPENAI_API_KEY1", logging_level="debug")
```
-
`ExperimentConfigBase` (`safetytooling.utils.experiment_utils.ExperimentConfigBase`)
- Description: A `dataclasses.dataclass` intended as a base class for experiment configuration. It standardizes common parameters (like `output_dir`, API settings) and provides an `api` property for easy `InferenceAPI` access and a `setup_experiment()` method for boilerplate setup (directory creation, logging, seeding).
- Key Attributes (to be set by subclass or via CLI with `simple_parsing`): `output_dir`, `cache_dir`, `prompt_history_dir`, API thread counts, etc.
- Key Methods/Properties:
  - `api` (property): Returns an initialized `InferenceAPI` instance based on the config.
  - `setup_experiment(log_file_prefix: str)`: Sets up logging to a file in `output_dir/logs`, sets random seeds, and calls `utils.setup_environment()`.
- Example (from `examples/anthropic_batch_api/run_anthropic_batch.py`):

```python
import asyncio
import dataclasses

from simple_parsing import ArgumentParser
from safetytooling.utils.experiment_utils import ExperimentConfigBase


@dataclasses.dataclass
class MyExperimentConfig(ExperimentConfigBase):
    model_id: str = "claude-3-opus-20240229"
    custom_param: int = 10


async def run_my_experiment(cfg: MyExperimentConfig):
    # cfg.api is now available and configured
    response = await cfg.api.ask_single_question(
        model_id=cfg.model_id,
        question="Hello!",
    )
    print(response[0])
    cfg.log_api_cost(metadata={"experiment_stage": "initial_call"})  # Logs cost


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_arguments(MyExperimentConfig, dest="exp_config")
    args = parser.parse_args()
    cfg: MyExperimentConfig = args.exp_config
    cfg.setup_experiment(log_file_prefix="my_exp_run")  # Sets up logging, seeds, etc.
    asyncio.run(run_my_experiment(cfg))
```
-
Prompt Utilities (`safetytooling.utils.prompt_utils`)
- Provides functions for loading and rendering prompts using Jinja2 templates.
  - `get_prompt_jinja_env(prompt_dir: str)`: Gets/creates a Jinja environment.
  - `get_prompt_template(template_path: str)`: Loads a specific template.
- Example: If you have `prompts/my_template.jinja` containing `Hello {{ name }}!`:

```python
from safetytooling.utils.prompt_utils import get_prompt_template

template = get_prompt_template("my_template.jinja")
rendered_prompt = template.render(name="World")
# rendered_prompt == "Hello World!"
```
-
VLLM Utilities (`safetytooling.utils.vllm_utils`)
- `deploy_model_vllm_locally_auto(...)`: Deploys a HuggingFace model locally using a VLLM server. Manages downloading the model and starting the server process. Returns a `VLLMDeployment` object.
- `deploy_model_vllm_locally(...)`: Similar to the auto variant, but requires manual specification of `base_model` and `lora_adapters`.
- The `VLLMDeployment` object should be closed (`deployment.close()`) when done, to terminate the server. The module attempts to clean up deployments on exit.
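Any resource that exposes a `close()` method pairs naturally with `contextlib.closing`, which guarantees cleanup even if the work in between raises. A generic sketch (the `Deployment` class here is a stand-in for illustration, not the library's `VLLMDeployment`):

```python
from contextlib import closing

class Deployment:
    """Stand-in for an object with a close() method (hypothetical example)."""
    def __init__(self):
        self.running = True

    def close(self):
        self.running = False  # a real deployment would terminate the server process here

dep = Deployment()
with closing(dep):
    assert dep.running  # use the deployment inside the block
assert not dep.running  # close() was called automatically on exit
```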
-
HuggingFace Utilities (`safetytooling.utils.hf_utils`)
- `get_adapter_config(adapter_id: str, token: str = None) -> dict`: Downloads and parses `adapter_config.json` for a LoRA adapter.
- `group_models_or_adapters_by_model(models: List[str], token: str = None) -> Dict[str, List[str]]`: Groups a list of model IDs into base models and their associated LoRA adapters.
- `resolve_lora_model(model: str) -> tuple[str, str | None]`: Given a model ID, returns its base model and LoRA adapter ID (if it's an adapter).
- `get_lora_rank(adapter_id: str, token: str = None) -> int`: Gets the LoRA rank `r` from an adapter's config.
- `push_model_to_hf(...)`: A Python wrapper for the `scripts/push_together_model_to_hf.sh` script, used to push models (e.g., fine-tuned with Together AI) to the Hugging Face Hub.
-
Other Utilities:
- `plot_utils.py`: Functions for plotting, e.g., confusion matrices (`plot_confusion_matrix`) and interactive data atlases (`plot_data_atlas`).
- `image_utils.py`: Image processing functions like adding text to images (`add_text_to_image`) and converting to base64 (`image_to_base64`).
- `audio_utils.py`: Audio processing functions, including a custom `WAVFile` class, SoX resampling, and audio format preparation for various APIs.
- `human_labeling_utils.py`: A framework for creating human labeling interfaces in Jupyter notebooks, storing labels persistently.
- `jailbreak_metrics.py`: Contains the `JailbreakMetrics` class for running inference over datasets and evaluating outputs using classifiers or logprobs, often used for safety evaluations.
- `text_utils.py`: Text normalization and functions for generating/handling tokenized attack strings.
- `caching_utils.py`: Provides a `file_cache` decorator for general-purpose function output caching to files.
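To illustrate the kind of mechanism a file-based cache decorator implies (a generic stdlib sketch, not the library's `file_cache` implementation): hash the function's arguments, and reuse a JSON file on disk whenever one already exists for that key.

```python
import functools
import hashlib
import json
from pathlib import Path

def file_cache(cache_dir: str):
    """Cache a function's JSON-serializable output to disk, keyed by its arguments."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            key_src = json.dumps([fn.__name__, args, kwargs], sort_keys=True, default=str)
            key = hashlib.sha256(key_src.encode()).hexdigest()
            path = Path(cache_dir) / f"{key}.json"
            if path.exists():
                return json.loads(path.read_text())  # cache hit: skip the call
            result = fn(*args, **kwargs)
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(json.dumps(result))
            return result
        return wrapper
    return decorator
```

A second call with the same arguments reads the stored file instead of recomputing, which is the same pattern the library applies to LLM API calls.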
- Description: A shell script located at `scripts/push_together_model_to_hf.sh`, designed to download model weights from a Together AI fine-tuning job and push them to a new or existing Hugging Face Hub repository as LoRA adapters.
- Usage:

```bash
./scripts/push_together_model_to_hf.sh <model_name_on_hf> <together_ft_id> <local_experiment_dir> <hf_username> [setup_flag]
```
- Arguments:
- `model_name_on_hf` (str): The desired repository name on Hugging Face (e.g., `my-cool-lora-model`).
- `together_ft_id` (str): The ID of your fine-tuning job on Together AI (e.g., `ft-12345678-abcd`).
- `local_experiment_dir` (str): A local directory path where the script will download files and set up the Git repository.
- `hf_username` (str): Your Hugging Face username.
- `setup_flag` (bool, optional, default: `False`): If `True`, the script will attempt to install `zstd` and `git-lfs`, run `git lfs install`, and attempt `huggingface-cli login`.
- Prerequisites for the script:
- `HF_TOKEN` and `TOGETHER_API_KEY` environment variables must be set. The script attempts to source them from `safety-tooling/SECRETS` if they are not found in the environment.
- `huggingface-cli` and `together-cli` installed and configured.
- Git and Git LFS installed.
- SSH key configured with your Hugging Face account for Git push operations.
- Python Wrapper: `safetytooling.utils.hf_utils.push_model_to_hf(...)` provides a Python interface to run this script.
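A wrapper like this typically just assembles the script's positional arguments and shells out. A hedged sketch (the argument order follows the usage line above; `build_push_command` is a hypothetical helper, and the real wrapper's signature may differ):

```python
import subprocess

def build_push_command(model_name_on_hf: str, together_ft_id: str,
                       local_experiment_dir: str, hf_username: str,
                       setup: bool = False) -> list[str]:
    """Assemble the argv for push_together_model_to_hf.sh (illustrative sketch)."""
    cmd = [
        "./scripts/push_together_model_to_hf.sh",
        model_name_on_hf,
        together_ft_id,
        local_experiment_dir,
        hf_username,
    ]
    if setup:
        cmd.append("True")  # optional setup_flag installs zstd/git-lfs and logs in
    return cmd

# A real wrapper would then run it, e.g.:
# subprocess.run(build_push_command("my-cool-lora-model", "ft-12345678-abcd",
#                                   "./exp", "my-user"), check=True)
```

Building the argv as a list (rather than a shell string) avoids quoting bugs when names contain spaces or shell metacharacters.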
safety-tooling/
├── .env.example # Example environment file for API keys and configs
├── .github/workflows/ # GitHub Actions CI configuration
│ └── lint_and_test.yml
├── .gitignore # Files and directories to ignore in Git
├── .pre-commit-config.yaml # Configuration for pre-commit hooks (linting/formatting)
├── LICENSE # Project license (MIT)
├── Makefile # Makefile for common tasks (e.g., installing pre-commit hooks)
├── README.md # Main README with overview and setup
├── examples/ # Example scripts and notebooks demonstrating usage
│ ├── anthropic_batch_api/ # Anthropic Batch API examples
│ ├── inference_api/ # InferenceAPI examples (including VLLM)
│ └── open_ai_batch_api/ # OpenAI Batch API examples
├── pyproject.toml # Project metadata, dependencies, and tool configurations (ruff, black, pytest)
├── requirements_dev.txt # Development-specific Python dependencies
├── safetytooling/ # Main source code for the library
│ ├── apis/ # Core API interaction logic
│ │ ├── __init__.py
│ │ ├── batch_api.py # BatchInferenceAPI class
│ │ ├── finetuning/ # Finetuning scripts and utilities
│ │ │ ├── openai/ # OpenAI finetuning
│ │ │ └── together/ # Together AI finetuning
│ │ ├── inference/ # LLM inference logic
│ │ │ ├── __init__.py
│ │ │ ├── api.py # InferenceAPI class (main entry point)
│ │ │ ├── cache_manager.py # Caching logic (file-based, Redis)
│ │ │ ├── model.py # Base InferenceAPIModel protocol
│ │ │ ├── anthropic.py # Anthropic client
│ │ │ ├── openai/ # OpenAI client (chat, completion, moderation, etc.)
│ │ │ ├── gemini/ # Google Gemini client (genai, vertexai)
│ │ │ ├── ... (other provider clients: huggingface, together, openrouter, etc.)
│ │ │ └── usage/ # API usage tracking scripts
│ │ ├── tts/ # Text-to-speech APIs
│ │ │ └── elevenlabs.py # ElevenLabs client
│ │ └── utils.py # API-specific utilities
│ ├── data_models/ # Pydantic models for data structures (Prompt, LLMResponse, etc.)
│ └── utils/ # General utility functions and classes
│ ├── experiment_utils.py # ExperimentConfigBase and related helpers
│ ├── audio_utils.py
│ ├── image_utils.py
│ ├── prompt_utils.py # Jinja templating for prompts
│ ├── vllm_utils.py # VLLM deployment helpers
│ └── ... (other utils: caching, hf, human_labeling, math, plot, text)
├── scripts/ # Utility shell scripts
│ └── push_together_model_to_hf.sh # Script to push Together fine-tuned models to HF Hub
└── tests/ # Pytest unit and integration tests
├── test_api.py
├── test_api_cache.py
├── test_batch_api.py
└── test_other_apis.py
Key File Locations:
- Configuration: `.env.example` (template for `.env`).
- Main Entry Point for Inference: `safetytooling/apis/inference/api.py` (defines `InferenceAPI`).
- Main Entry Point for Batch Inference: `safetytooling/apis/batch_api.py` (defines `BatchInferenceAPI`).
- Core Data Structures: `safetytooling/data_models/`.
- OpenAI Finetuning Script: `safetytooling/apis/finetuning/openai/run.py`.
- Together Finetuning Script: `safetytooling/apis/finetuning/together/run.py`.
- General Utilities & Experiment Setup: `safetytooling/utils/`.
Tests are written using pytest.
-
Ensure all dependencies (including dev) are installed:

```bash
uv pip install -e .
uv pip install -r requirements_dev.txt
```
-
Set necessary API keys in your `.env` file or environment. Some tests make live API calls. Keys needed for tests: `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`.
-
Run tests:
```bash
python -m pytest -v -s -n 6
```

- `-v`: Verbose output.
- `-s`: Show print statements (output from tests).
- `-n 6`: Run tests in parallel with 6 workers (adjust as needed).
-
Running Slow Tests (e.g., Batch API tests): Some tests are marked as slow and are skipped by default to speed up CI. To run them, set the `SAFETYTOOLING_SLOW_TESTS` environment variable to `true`:

```bash
SAFETYTOOLING_SLOW_TESTS=True python -m pytest -v -s -n 6
```
- Linting: Ruff is used for linting.
- Formatting: Black is used for code formatting.
- Pre-commit Hooks: The project uses pre-commit hooks to automatically run Ruff and Black before commits.
- To install and enable pre-commit hooks:
```bash
make hooks
```

This will run `pre-commit install` and ensure hooks are active for the `pre-commit`, `post-checkout`, and `pre-push` stages.
.pre-commit-config.yaml.
- VSCode Integration: If using VSCode, install the official Ruff extension for live linting and formatting.
- Import Sorting: Imports are sorted according to Ruff's `isort` integration (see `pyproject.toml` for config).
- Line Length: 120 characters (configured in `pyproject.toml` for Ruff and Black).
This document serves as the primary documentation. There is no separate Sphinx or MkDocs build process explicitly defined in the repository structure for user-facing documentation generation.
- Application Dependencies: To add or update a core dependency, modify the `dependencies` list in `pyproject.toml`.
- Development Dependencies: To add or update a development-only dependency, modify `requirements_dev.txt`.
- Installation: After modifying dependencies, re-install them:

```bash
uv pip install -e .
uv pip install -r requirements_dev.txt
```

- Checking for Outdated Dependencies:

```bash
uv pip list --outdated
```
- API Key Issues:
- Ensure your API keys are correctly set in the `.env` file at the root of the repository.
- Verify that `safetytooling.utils.utils.setup_environment()` is called at the beginning of your script/notebook to load these keys.
- If using alternative keys (e.g., `OPENAI_API_KEY1`), ensure you are specifying the correct `openai_tag` or `anthropic_tag` in `setup_environment()`, or passing the key directly to `InferenceAPI`.
- Rate Limit Errors:
- `safety-tooling` attempts to manage rate limits, but you might still encounter them if your usage is very high or if shared keys are exhausted.
- Consider reducing `openai_num_threads`, `anthropic_num_threads`, etc., in `InferenceAPI` initialization.
- For OpenAI, reduce `openai_fraction_rate_limit`.
- The API usage scripts (`usage_openai.py`, `usage_anthropic.py`) can help diagnose whether you are close to the limits.
- Model Not Found / Invalid Model ID:
- Double-check the `model_id` string for typos.
- Ensure the model is supported by the intended provider.
- For newly released models not yet hardcoded in `safety-tooling`, you might need to use the `force_provider` argument in `InferenceAPI.__call__()`.
- For local VLLM deployments, ensure the `model_id` passed to `InferenceAPI` matches the model name the VLLM server is serving.
- VLLM Server Issues:
- Check the VLLM server logs (the path is printed when `deploy_model_vllm_locally_auto` is called) for errors.
- Ensure VLLM is installed (`uv pip install vllm`) and that you have compatible CUDA drivers.
- Make sure the port VLLM is trying to use (default 8000) is not already in use.
- HuggingFace `handler.md` for Together AI fine-tuned models:
  - The `handler.md` example is for deploying a LoRA model. Ensure the `base_model` and `adapter_model` paths in the handler script correctly point to your models on the Hugging Face Hub, or to the local paths within the inference endpoint's environment.
  - The `cache_dir` in the handler might need adjustment based on the inference endpoint's writable paths.
- `ImportError` for `ipywidgets` or `IPython.display`:
  - These are needed for `TranscriptLabelingInterface`. Install them via `requirements_dev.txt` or `uv pip install ipywidgets IPython`.
- FFmpeg not found (for TTS WAV conversion or some audio utils):
- Ensure FFmpeg is installed on your system and accessible in your PATH (e.g., `sudo apt-get install ffmpeg` on Debian/Ubuntu).
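A quick way to confirm this prerequisite from Python before running the audio utilities (a generic stdlib check, not part of safety-tooling):

```python
import shutil

def ffmpeg_available() -> bool:
    """True if an `ffmpeg` executable is found on the current PATH."""
    return shutil.which("ffmpeg") is not None

if not ffmpeg_available():
    print("ffmpeg not found; install it (e.g. `sudo apt-get install ffmpeg`)")
```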