Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions graph_of_thoughts/language_models/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import random
import time
from typing import List, Dict, Union
from openai import OpenAI, OpenAIError
import OpenAI
from openai.types.chat.chat_completion import ChatCompletion

from .abstract_language_model import AbstractLanguageModel
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(
if self.api_key == "":
raise ValueError("OPENAI_API_KEY is not set")
# Initialize the OpenAI Client
self.client = OpenAI(api_key=self.api_key, organization=self.organization)
self.client = openai

def query(
self, query: str, num_responses: int = 1
Expand Down Expand Up @@ -101,7 +101,7 @@ def query(
self.response_cache[query] = response
return response

@backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6)
@backoff.on_exception(backoff.expo, openai.OpenAIError, max_time=10, max_tries=6)
def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion:
"""
Send chat messages to the OpenAI model and retrieves the model's response.
Expand All @@ -114,7 +114,7 @@ def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion:
:return: The OpenAI model's response.
:rtype: ChatCompletion
"""
response = self.client.chat.completions.create(
response = self.client.ChatCompletions.create(
model=self.model_id,
messages=messages,
temperature=self.temperature,
Expand All @@ -123,8 +123,8 @@ def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion:
stop=self.stop,
)

self.prompt_tokens += response.usage.prompt_tokens
self.completion_tokens += response.usage.completion_tokens
self.prompt_tokens += response["usage"]["prompt_tokens"]
self.completion_tokens += response["usage"]["completion_tokens"]
prompt_tokens_k = float(self.prompt_tokens) / 1000.0
completion_tokens_k = float(self.completion_tokens) / 1000.0
self.cost = (
Expand All @@ -151,7 +151,8 @@ def get_response_texts(
if not isinstance(query_response, List):
query_response = [query_response]
return [
choice.message.content
choice["message"]["content"]
for response in query_response
for choice in response.choices
for choice in response["choices"]
]