1515
1616from __future__ import annotations
1717
18+ import hashlib
1819import logging
1920import posixpath
2021import threading
22+ from collections import OrderedDict
2123from concurrent .futures import Future , ThreadPoolExecutor
2224from contextlib import ExitStack
2325from dataclasses import asdict , dataclass
@@ -73,8 +75,16 @@ class CompletionRefs:
7375
7476JsonEncodeable = list [dict [str , Any ]]
7577
76- # mapping of upload path to function computing upload data dict
77- UploadData = dict [str , Callable [[], JsonEncodeable ]]
78+ # mapping from (upload path, whether the filename is derived from a hash of the contents) to a function computing the upload data
79+ UploadData = dict [tuple [str , bool ], Callable [[], JsonEncodeable ]]
80+
81+
def is_system_instructions_hashable(
    system_instruction: list[types.MessagePart] | None,
) -> bool:
    """Tell whether a system instruction can be identified by a content hash.

    Returns True only when ``system_instruction`` is non-empty and every
    part in it is a ``types.Text`` instance; ``None`` and an empty list
    both yield False.
    """
    # Nothing to hash when the instruction is missing or empty.
    if not system_instruction:
        return False
    # A single non-text part makes the whole instruction unhashable.
    for part in system_instruction:
        if not isinstance(part, types.Text):
            return False
    return True
7888
7989
8090class UploadCompletionHook (CompletionHook ):
@@ -97,10 +107,13 @@ def __init__(
97107 base_path : str ,
98108 max_size : int = 20 ,
99109 upload_format : Format | None = None ,
110+ lru_cache_max_size : int = 1024 ,
100111 ) -> None :
101112 self ._max_size = max_size
102113 self ._fs , base_path = fsspec .url_to_fs (base_path )
103114 self ._base_path = self ._fs .unstrip_protocol (base_path )
115+ self .lru_dict : OrderedDict [str , bool ] = OrderedDict ()
116+ self .lru_cache_max_size = lru_cache_max_size
104117
105118 if upload_format not in _FORMATS + (None ,):
106119 raise ValueError (
@@ -132,7 +145,10 @@ def done(future: Future[None]) -> None:
132145 finally :
133146 self ._semaphore .release ()
134147
135- for path , json_encodeable in upload_data .items ():
148+ for (
149+ path ,
150+ contents_hashed_to_filename ,
151+ ), json_encodeable in upload_data .items ():
136152 # could not acquire, drop data
137153 if not self ._semaphore .acquire (blocking = False ): # pylint: disable=consider-using-with
138154 _logger .warning (
@@ -143,7 +159,10 @@ def done(future: Future[None]) -> None:
143159
144160 try :
145161 fut = self ._executor .submit (
146- self ._do_upload , path , json_encodeable
162+ self ._do_upload ,
163+ path ,
164+ contents_hashed_to_filename ,
165+ json_encodeable ,
147166 )
148167 fut .add_done_callback (done )
149168 except RuntimeError :
@@ -152,10 +171,20 @@ def done(future: Future[None]) -> None:
152171 )
153172 self ._semaphore .release ()
154173
155- def _calculate_ref_path (self ) -> CompletionRefs :
174+ def _calculate_ref_path (
175+ self , system_instruction : list [types .MessagePart ]
176+ ) -> CompletionRefs :
156177 # TODO: experimental with using the trace_id and span_id, or fetching
157178 # gen_ai.response.id from the active span.
158-
179+ system_instruction_hash = None
180+ if is_system_instructions_hashable (system_instruction ):
181+ # Get a hash of the text.
182+ system_instruction_hash = hashlib .sha256 (
183+ "\n " .join (x .content for x in system_instruction ).encode ( # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue, reportUnknownArgumentType]
184+ "utf-8"
185+ ),
186+ usedforsecurity = False ,
187+ ).hexdigest ()
159188 uuid_str = str (uuid4 ())
160189 return CompletionRefs (
161190 inputs_ref = posixpath .join (
@@ -166,13 +195,32 @@ def _calculate_ref_path(self) -> CompletionRefs:
166195 ),
167196 system_instruction_ref = posixpath .join (
168197 self ._base_path ,
169- f"{ uuid_str } _system_instruction.{ self ._format } " ,
198+ f"{ system_instruction_hash or uuid_str } _system_instruction.{ self ._format } " ,
170199 ),
171200 )
172201
202+ def _file_exists (self , path : str ) -> bool :
203+ if path in self .lru_dict :
204+ self .lru_dict .move_to_end (path )
205+ return True
206+ # https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.exists
207+ file_exists = self ._fs .exists (path )
208+ # don't cache this because soon the file will exist..
209+ if not file_exists :
210+ return False
211+ self .lru_dict [path ] = True
212+ if len (self .lru_dict ) > self .lru_cache_max_size :
213+ self .lru_dict .popitem (last = False )
214+ return True
215+
173216 def _do_upload (
174- self , path : str , json_encodeable : Callable [[], JsonEncodeable ]
217+ self ,
218+ path : str ,
219+ contents_hashed_to_filename : bool ,
220+ json_encodeable : Callable [[], JsonEncodeable ],
175221 ) -> None :
222+ if contents_hashed_to_filename and self ._file_exists (path ):
223+ return
176224 if self ._format == "json" :
177225 # output as a single line with the json messages array
178226 message_lines = [json_encodeable ()]
@@ -194,6 +242,11 @@ def _do_upload(
194242 gen_ai_json_dump (message , file )
195243 file .write ("\n " )
196244
245+ if contents_hashed_to_filename :
246+ self .lru_dict [path ] = True
247+ if len (self .lru_dict ) > self .lru_cache_max_size :
248+ self .lru_dict .popitem (last = False )
249+
197250 def on_completion (
198251 self ,
199252 * ,
@@ -213,7 +266,7 @@ def on_completion(
213266 system_instruction = system_instruction or None ,
214267 )
215268 # generate the paths to upload to
216- ref_names = self ._calculate_ref_path ()
269+ ref_names = self ._calculate_ref_path (system_instruction )
217270
218271 def to_dict (
219272 dataclass_list : list [types .InputMessage ]
@@ -223,35 +276,40 @@ def to_dict(
223276 return [asdict (dc ) for dc in dataclass_list ]
224277
225278 references = [
226- (ref_name , ref , ref_attr )
227- for ref_name , ref , ref_attr in [
279+ (ref_name , ref , ref_attr , contents_hashed_to_filename )
280+ for ref_name , ref , ref_attr , contents_hashed_to_filename in [
228281 (
229282 ref_names .inputs_ref ,
230283 completion .inputs ,
231284 GEN_AI_INPUT_MESSAGES_REF ,
285+ False ,
232286 ),
233287 (
234288 ref_names .outputs_ref ,
235289 completion .outputs ,
236290 GEN_AI_OUTPUT_MESSAGES_REF ,
291+ False ,
237292 ),
238293 (
239294 ref_names .system_instruction_ref ,
240295 completion .system_instruction ,
241296 GEN_AI_SYSTEM_INSTRUCTIONS_REF ,
297+ is_system_instructions_hashable (
298+ completion .system_instruction
299+ ),
242300 ),
243301 ]
244- if ref
302+ if ref # Filter out empty input/output/sys instruction
245303 ]
246304 self ._submit_all (
247305 {
248- ref_name : partial (to_dict , ref )
249- for ref_name , ref , _ in references
306+ ( ref_name , contents_hashed_to_filename ) : partial (to_dict , ref )
307+ for ref_name , ref , _ , contents_hashed_to_filename in references
250308 }
251309 )
252310
253311 # stamp the refs on telemetry
254- references = {ref_attr : name for name , _ , ref_attr in references }
312+ references = {ref_attr : name for name , _ , ref_attr , _ in references }
255313 if span :
256314 span .set_attributes (references )
257315 if log_record :
0 commit comments