@@ -310,7 +310,7 @@ def to_onnx(self) -> ModelProto:
310310 return model
311311
312312
313- def flatten_iterator (obj : Any , sep : str ) -> Iterator :
313+ def _flatten_iterator (obj : Any , sep : str ) -> Iterator :
314314 """Iterates on all object."""
315315 if obj is not None :
316316 if isinstance (obj , np .ndarray ):
@@ -329,21 +329,21 @@ def flatten_iterator(obj: Any, sep: str) -> Iterator:
329329 else :
330330 for i , o in enumerate (obj ):
331331 if i == len (obj ) - 1 :
332- for p , oo in flatten_iterator (o , sep ):
332+ for p , oo in _flatten_iterator (o , sep ):
333333 yield f"tuple.{ sep } { p } " , oo
334334 else :
335- for p , oo in flatten_iterator (o , sep ):
335+ for p , oo in _flatten_iterator (o , sep ):
336336 yield f"tuple{ sep } { p } " , oo
337337 elif isinstance (obj , list ):
338338 if not obj :
339339 yield f"list.{ sep } empty" , None
340340 else :
341341 for i , o in enumerate (obj ):
342342 if i == len (obj ) - 1 :
343- for p , oo in flatten_iterator (o , sep ):
343+ for p , oo in _flatten_iterator (o , sep ):
344344 yield f"list.{ sep } { p } " , oo
345345 else :
346- for p , oo in flatten_iterator (o , sep ):
346+ for p , oo in _flatten_iterator (o , sep ):
347347 yield f"list{ sep } { p } " , oo
348348 elif isinstance (obj , dict ):
349349 if not obj :
@@ -352,13 +352,13 @@ def flatten_iterator(obj: Any, sep: str) -> Iterator:
352352 for i , (k , v ) in enumerate (obj .items ()):
353353 assert sep not in k , (
354354 f"Key { k !r} cannot contain '{ sep } '. "
355- f"It would interfer with the serialization."
355+ f"It would interfere with the serialization."
356356 )
357357 if i == len (obj ) - 1 :
358- for p , o in flatten_iterator (v , sep ):
358+ for p , o in _flatten_iterator (v , sep ):
359359 yield f"dict._{ k } { sep } { p } " , o
360360 else :
361- for p , o in flatten_iterator (v , sep ):
361+ for p , o in _flatten_iterator (v , sep ):
362362 yield f"dict_{ k } { sep } { p } " , o
363363 elif obj .__class__ .__name__ == "DynamicCache" :
364364 # transformers
@@ -370,10 +370,10 @@ def flatten_iterator(obj: Any, sep: str) -> Iterator:
370370 atts = ["key_cache" , "value_cache" ]
371371 for i , att in enumerate (atts ):
372372 if i == len (atts ) - 1 :
373- for p , o in flatten_iterator (getattr (obj , att ), sep ):
373+ for p , o in _flatten_iterator (getattr (obj , att ), sep ):
374374 yield f"DynamicCache._{ att } { sep } { p } " , o
375375 else :
376- for p , o in flatten_iterator (getattr (obj , att ), sep ):
376+ for p , o in _flatten_iterator (getattr (obj , att ), sep ):
377377 yield f"DynamicCache_{ att } { sep } { p } " , o
378378 else :
379379 raise NotImplementedError (f"Unexpected type { type (obj )} " )
@@ -403,7 +403,7 @@ def create_onnx_model_from_input_tensors(
403403 switch_low_high = sys .byteorder != "big"
404404
405405 builder = MiniOnnxBuilder (sep = sep )
406- for prefix , o in flatten_iterator (inputs , sep ):
406+ for prefix , o in _flatten_iterator (inputs , sep ):
407407 if o is None :
408408 builder .append_output_initializer (prefix , np .array ([]))
409409 else :
@@ -413,17 +413,15 @@ def create_onnx_model_from_input_tensors(
413413 return model
414414
415415
416- def unflatten (
416+ def _unflatten (
417417 sep : str ,
418418 names : List [str ],
419419 outputs : List [Any ],
420420 pos : int = 0 ,
421421 level : int = 0 ,
422422 device : str = "cpu" ,
423423) -> Tuple [int , Tuple [Any , ...]]:
424- """
425- Unflattens a list of outputs flattened with :func:`flatten_iterator`.
426- """
424+ """Unflattens a list of outputs flattened with :func:`flatten_iterator`."""
427425 name = names [pos ]
428426 spl = name .split (sep )
429427 if len (spl ) == level + 1 :
@@ -448,7 +446,7 @@ def unflatten(
448446 name = names [pos ]
449447 spl = name .split (sep )
450448 prefix = spl [level ]
451- next_pos , value = unflatten (
449+ next_pos , value = _unflatten (
452450 sep , names , outputs , pos = pos , level = level + 1 , device = device
453451 )
454452
@@ -499,7 +497,7 @@ def create_input_tensors_from_onnx_model(
499497 device : str = "cpu" ,
500498 engine : str = "ExtendedReferenceEvaluator" ,
501499 sep : str = "___" ,
502- ) -> Union [ Tuple [ Any , ...], Dict [ str , Any ]] :
500+ ) -> Any :
503501 """
504502 Deserializes tensors stored with function
505503 :func:`create_onnx_model_from_input_tensors`.
@@ -511,7 +509,7 @@ def create_input_tensors_from_onnx_model(
511509 :param device: moves the tensor to this device
512510 :param engine: runtime to use, onnx, the default value, onnxruntime
513511 :param sep: separator
514- :return: ModelProto
512+ :return: restored data
515513 """
516514 if engine == "ExtendedReferenceEvaluator" :
517515 from ..reference import ExtendedReferenceEvaluator
@@ -552,4 +550,4 @@ def create_input_tensors_from_onnx_model(
552550 return torch .from_numpy (output ).to (device )
553551 raise AssertionError (f"Unexpected name { name !r} in { names } " )
554552
555- return unflatten (sep , names , got , device = device )[1 ]
553+ return _unflatten (sep , names , got , device = device )[1 ]
0 commit comments