11import json
2- from functools import cache
32
4- from groq import NOT_GIVEN , BadRequestError , Groq
3+ from litellm import BadRequestError , completion , validate_environment
54
6- from src .config import GROQ_API_KEY
5+ from src .config import LLM_API_KEY , LLM_MODEL , PROVIDER
76
8- LLM_MODEL = ["llama-3.1-8b-instant" , "llama3-70b-8192" ][1 ]
97
def validate_model(provider: str = PROVIDER, model: str = LLM_MODEL) -> None:
    """Check the configured provider/model pair via litellm's ``validate_environment``.

    Args:
        provider: Provider prefix placed before the slash in the model identifier.
        model: Model name placed after the slash in the model identifier.

    Raises:
        AssertionError: If ``validate_environment`` reports a truthy
            ``keys_in_environment`` for ``"<provider>/<model>"``.
    """
    # Build the identifier as "<provider>/<model>" with no padding around the
    # slash, so the provider prefix is passed through exactly as configured.
    model_str = f"{provider}/{model}"
    report = validate_environment(model_str)

    # Raise explicitly instead of using `assert`, so the check still runs when
    # Python is started with -O; the exception type is kept as AssertionError.
    # NOTE(review): this fails when the provider keys ARE found in the
    # environment, which looks inverted for a "model is valid" check -- confirm
    # the intended condition (perhaps `missing_keys` should be inspected).
    if report["keys_in_environment"]:
        raise AssertionError(f"Invalid value : { model } ")
1011
11- @cache
12- def get_groq_client ():
13- return Groq (api_key = GROQ_API_KEY )
1412
15-
16- def ask_groq (query_text : str , system_content : str , json_schema : dict = None ) -> dict :
17- client = get_groq_client ()
18- response_format = NOT_GIVEN
13+ def ask_llm (query_text : str , system_content : str , json_schema : dict = None , provider = PROVIDER ) -> dict :
14+ response_format = None
1915 if json_schema :
2016 system_content = f"{ system_content } \n Use this json schema to reply { json .dumps (json_schema )} "
2117 response_format = {"type" : "json_object" }
2218 try :
23- chat_completion = client .chat .completions .create (
19+ # Send a message to the model
20+ print ("calling llm..." )
21+ response = completion (
22+ model = f"{ provider } /{ LLM_MODEL } " ,
23+ api_key = LLM_API_KEY ,
2424 messages = [{"role" : "system" , "content" : system_content }, {"role" : "user" , "content" : query_text }],
25- model = LLM_MODEL ,
2625 response_format = response_format ,
2726 )
28- json_str = chat_completion .choices [0 ].message .content
29- json_str = json_str .strip ()
27+ json_str = response ["choices" ][0 ]["message" ]["content" ]
3028 except BadRequestError as e :
31- failed_json = e .body ["error" ]["failed_generation" ] # noqa
32- json_str = failed_json .replace ('"""' , '"' ).replace ("\n " , "\\ n" )
29+ json_str = e .message # reusing variable for showing error
30+
31+ if json_str .startswith ("litellm" ):
32+ raise RuntimeError (json_str )
33+
3334 if not json_schema :
3435 return json_str
3536 try :
36- # json_str = re.sub("\n +", "", json_str)
3737 return json .loads (json_str )
3838 except json .JSONDecodeError as e :
3939 print (e .doc .encode (), e .pos )
@@ -47,7 +47,7 @@ def find_typos(query_text):
4747 " and list mistakes (if any) along with suggestion, not fixes."
4848 "Keep it short, no explanation. Do not modify urls, usernames and hashtags."
4949 )
50- ans = ask_groq (query_text , system_content , json_schema = {"suggestions" : ["str" ]})
50+ ans = ask_llm (query_text , system_content , json_schema = {"suggestions" : ["str" ]})
5151 return ans ["suggestions" ]
5252
5353
@@ -57,12 +57,12 @@ def fix_typos(query_text):
5757 "without any additional info like headings, footers, etc."
5858 "Do not modify urls, usernames and hashtags. `~~~` is a phrase seperator, keep it as it is."
5959 )
60- return ask_groq (query_text , system_content )
60+ return ask_llm (query_text , system_content )
6161
6262
if __name__ == "__main__":
    # Manual smoke run: call load_dotenv(), send one prompt, print the reply.
    import dotenv  # noqa

    dotenv.load_dotenv()
    reply = ask_llm("Debuging is hard.", system_content="")
    print(f"Response from model: { reply } ")
0 commit comments