diff --git a/graph_of_thoughts/language_models/chatgpt.py b/graph_of_thoughts/language_models/chatgpt.py index 52dc1d3..f81e2e3 100644 --- a/graph_of_thoughts/language_models/chatgpt.py +++ b/graph_of_thoughts/language_models/chatgpt.py @@ -11,7 +11,7 @@ import random import time from typing import List, Dict, Union -from openai import OpenAI, OpenAIError +import OpenAI from openai.types.chat.chat_completion import ChatCompletion from .abstract_language_model import AbstractLanguageModel @@ -58,7 +58,7 @@ def __init__( if self.api_key == "": raise ValueError("OPENAI_API_KEY is not set") # Initialize the OpenAI Client - self.client = OpenAI(api_key=self.api_key, organization=self.organization) + self.client = openai def query( self, query: str, num_responses: int = 1 @@ -101,7 +101,7 @@ def query( self.response_cache[query] = response return response - @backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6) + @backoff.on_exception(backoff.expo, openai.OpenAIError, max_time=10, max_tries=6) def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion: """ Send chat messages to the OpenAI model and retrieves the model's response. @@ -114,7 +114,7 @@ def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion: :return: The OpenAI model's response. :rtype: ChatCompletion """ - response = self.client.chat.completions.create( + response = self.client.ChatCompletions.create( model=self.model_id, messages=messages, temperature=self.temperature, @@ -123,8 +123,8 @@ def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion: stop=self.stop, ) - self.prompt_tokens += response.usage.prompt_tokens - self.completion_tokens += response.usage.completion_tokens + self.prompt_tokens += response["usage"]["prompt_tokens"] + self.completion_tokens += response["usage"]["completion_tokens"] prompt_tokens_k = float(self.prompt_tokens) / 1000.0 completion_tokens_k = float(self.completion_tokens) / 1000.0 self.cost = ( @@ -151,7 +151,8 @@ def get_response_texts( if not isinstance(query_response, List): query_response = [query_response] return [ - choice.message.content + choice["message"]["content"] for response in query_response - for choice in response.choices + for choice in response["choices"] ] +