@@ -1,51 +1,45 @@
-import os
-import uuid
-import ujson
 import functools
-from pathlib import Path
+import os
+import uuid
 from datetime import datetime
+from pathlib import Path
 
-try:
-    import warnings
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", category=UserWarning)
-        if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
-            os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
-        import litellm
-        litellm.telemetry = False
+import litellm
+import ujson
+from litellm.caching import Cache
 
-        from litellm.caching import Cache
-        disk_cache_dir = os.environ.get('DSPY_CACHEDIR') or os.path.join(Path.home(), '.dspy_cache')
-        litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
+disk_cache_dir = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
+litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
+litellm.telemetry = False
 
-except ImportError:
-    class LitellmPlaceholder:
-        def __getattr__(self, _): raise ImportError("The LiteLLM package is not installed. Run `pip install litellm`.")
+if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
 
-    litellm = LitellmPlaceholder()
 
 class LM:
-    def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs):
+    def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True, **kwargs):
         self.model = model
         self.model_type = model_type
         self.cache = cache
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
         self.history = []
 
         if "o1-" in model:
-            assert max_tokens >= 5000 and temperature == 1.0, \
-                "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
-
-
+            assert (
+                max_tokens >= 5000 and temperature == 1.0
+            ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
+
     def __call__(self, prompt=None, messages=None, **kwargs):
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
 
         # Make the request and handle LRU & disk caching.
-        if self.model_type == "chat": completion = cached_litellm_completion if cache else litellm_completion
-        else: completion = cached_litellm_text_completion if cache else litellm_text_completion
+        if self.model_type == "chat":
+            completion = cached_litellm_completion if cache else litellm_completion
+        else:
+            completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -63,8 +57,9 @@ def __call__(self, prompt=None, messages=None, **kwargs):
             model_type=self.model_type,
         )
         self.history.append(entry)
+
         return outputs
-
+
     def inspect_history(self, n: int = 1):
         _inspect_history(self, n)
 
@@ -73,14 +68,17 @@ def inspect_history(self, n: int = 1):
 def cached_litellm_completion(request):
     return litellm_completion(request, cache={"no-cache": False, "no-store": False})
 
+
 def litellm_completion(request, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
     return litellm.completion(cache=cache, **kwargs)
 
+
 @functools.lru_cache(maxsize=None)
 def cached_litellm_text_completion(request):
     return litellm_text_completion(request, cache={"no-cache": False, "no-store": False})
 
+
 def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
 
@@ -93,32 +91,40 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True})
     api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
 
     # Build the prompt from the messages.
-    prompt = '\n\n'.join([x['content'] for x in kwargs.pop("messages")] + ['BEGIN RESPONSE:'])
+    prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
 
-    return litellm.text_completion(cache=cache, model=f'text-completion-openai/{model}', api_key=api_key,
-                                   api_base=api_base, prompt=prompt, **kwargs)
+    return litellm.text_completion(
+        cache=cache,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
+        prompt=prompt,
+        **kwargs,
+    )
 
 
 def _green(text: str, end: str = "\n"):
     return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end
 
+
 def _red(text: str, end: str = "\n"):
     return "\x1b[31m" + str(text) + "\x1b[0m" + end
 
+
 def _inspect_history(lm, n: int = 1):
     """Prints the last n prompts and their completions."""
 
     for item in lm.history[-n:]:
-        messages = item["messages"] or [{"role": "user", "content": item['prompt']}]
+        messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
         outputs = item["outputs"]
         timestamp = item.get("timestamp", "Unknown time")
 
         print("\n\n\n")
         print("\x1b [34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")
-
+
         for msg in messages:
             print(_red(f"{msg['role'].capitalize()} message:"))
-            print(msg['content'].strip())
+            print(msg["content"].strip())
             print("\n")
 
         print(_red("Response:"))
@@ -127,5 +133,5 @@ def _inspect_history(lm, n: int = 1):
         if len(outputs) > 1:
             choices_text = f" \t (and {len(outputs)-1} other completions)"
             print(_red(choices_text, end=""))
-
-    print("\n\n\n")
+
+    print("\n\n\n")
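
For orientation, here is a minimal usage sketch of the `LM` class as it stands after this commit. The model name and prompts are placeholders (not part of the commit), and it assumes litellm is installed and the relevant provider API key (e.g. OPENAI_API_KEY) is set in the environment:

    # Hypothetical example; "openai/gpt-4o-mini" and the prompts are illustrative.
    lm = LM("openai/gpt-4o-mini", model_type="chat", temperature=0.0, max_tokens=1000)

    # __call__ accepts either a bare prompt or a chat-style messages list.
    outputs = lm("What is 2 + 2?")
    outputs = lm(messages=[{"role": "user", "content": "What is 2 + 2?"}])

    # cache=False overrides the instance default and bypasses the cached helpers.
    outputs = lm("What is 2 + 2?", cache=False)

    # Pretty-print the most recent request/response pair.
    lm.inspect_history(n=1)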
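A design note on the caching path this commit reformats: `__call__` serializes the request dict with `ujson.dumps` before dispatching, so the `functools.lru_cache`-wrapped helpers key on a hashable string rather than on an unhashable dict of kwargs. A standalone sketch of the same pattern, where `expensive_api_call` is a hypothetical stand-in for `litellm.completion`:

    import functools

    import ujson


    def expensive_api_call(**kwargs):
        # Stand-in for litellm.completion; it just echoes its arguments here.
        return kwargs


    @functools.lru_cache(maxsize=None)
    def cached_call(request: str):
        # lru_cache requires hashable arguments, so callers pass a JSON string
        # and we round-trip it back into kwargs before the real call.
        return expensive_api_call(**ujson.loads(request))


    request = ujson.dumps(dict(model="m", messages=[{"role": "user", "content": "hi"}]))
    assert cached_call(request) is cached_call(request)  # second call is an LRU hit

The `cache={"no-cache": ..., "no-store": ...}` dicts passed through to litellm then control the second, disk-backed layer configured at import time via `litellm.cache = Cache(...)`.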