Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

from openai import OpenAI

embedding_dimensions = 256

def get_embedding(text, model="text-embedding-3-small"):

def get_embedding(text, model: str):
"""
Get embeddings from OpenAI API

Expand All @@ -21,7 +23,7 @@ def get_embedding(text, model="text-embedding-3-small"):
client = OpenAI()

# Get the embedding from OpenAI
response = client.embeddings.create(input=text, model=model)
response = client.embeddings.create(input=text, model=model, dimensions=embedding_dimensions)

# Extract and return the embedding vector
return response.data[0].embedding
Expand All @@ -39,18 +41,22 @@ def stabilize_float(x: float) -> float:
return struct.unpack("f", struct.pack("f", x))[0]


def create_embedding_object(text: str, model="text-embedding-3-small") -> dict:
def create_embedding_object(text: str) -> dict:
    """
    Create an embedding object for the given text with the default model

    Thin convenience wrapper: forwards to create_embedding_object_model
    with the model name fixed to "text-embedding-3-small".

    Args:
        text (str): Text to embed

    Returns:
        dict: The object produced by create_embedding_object_model
    """
    default_model = "text-embedding-3-small"
    return create_embedding_object_model(text, default_model)


def create_embedding_object_model(text: str, model: str) -> dict:
"""
Create an embedding object with metadata

Args:
text (str): Text to embed
model (str): Model to use for embedding
Note: the OpenAI API key is not an argument; it must be set in the OPENAI_API_KEY environment variable

Returns:
dict: Object with text, model and embedding
"""

if not os.environ.get("OPENAI_API_KEY"):
raise ValueError("OPENAI_API_KEY not provided or set in environment")

Expand Down Expand Up @@ -79,14 +85,10 @@ def main():
parser = argparse.ArgumentParser(description="Generate OpenAI embeddings and save to JSON")
parser.add_argument("text", type=str, help="Text to embed")
parser.add_argument("--output", type=str, default="embedding.json", help="Output JSON file")
parser.add_argument(
"--model", type=str, default="text-embedding-3-small", help="OpenAI embedding model to use"
)
parser.add_argument("--api-key", type=str, help="OpenAI API key (optional)")
args = parser.parse_args()

# Create embedding object
embedding_obj = create_embedding_object(args.text, model=args.model, api_key=args.api_key)
embedding_obj = create_embedding_object(args.text)

# Save to JSON file
output_path = save_embedding(embedding_obj, args.output)
Expand Down
Loading