@@ -343,65 +343,23 @@ def string_type(
343343
344344 # others classes
345345
346- if type (obj ).__name__ == "MambaCache" :
347- c = string_type (
348- obj .conv_states ,
349- with_shape = with_shape ,
350- with_min_max = with_min_max ,
351- with_device = with_device ,
352- limit = limit ,
353- )
354- d = string_type (
355- obj .ssm_states ,
346+ if obj .__class__ in torch .utils ._pytree .SUPPORTED_NODES :
347+ args , _spec = torch .utils ._pytree .tree_flatten (obj )
348+ att = string_type (
349+ args ,
356350 with_shape = with_shape ,
357351 with_min_max = with_min_max ,
358352 with_device = with_device ,
359353 limit = limit ,
360354 )
361- return f"MambaCache(conv_states={ c } , ssm_states={ d } )"
355+ return f"{ obj .__class__ .__name__ } [serialized]({ att } )"
356+
362357 if type (obj ).__name__ == "Node" and hasattr (obj , "meta" ):
363358 # torch.fx.node.Node
364359 return f"%{ obj .target } "
365360 if type (obj ).__name__ == "ValueInfoProto" :
366361 return f"OT{ obj .type .tensor_type .elem_type } "
367362
368- if obj .__class__ .__name__ == "DynamicCache" :
369- kc = string_type (
370- obj .key_cache ,
371- with_shape = with_shape ,
372- with_min_max = with_min_max ,
373- with_device = with_device ,
374- limit = limit ,
375- )
376- vc = string_type (
377- obj .value_cache ,
378- with_shape = with_shape ,
379- with_min_max = with_min_max ,
380- with_device = with_device ,
381- limit = limit ,
382- )
383- return f"{ obj .__class__ .__name__ } (key_cache={ kc } , value_cache={ vc } )"
384-
385- if obj .__class__ .__name__ == "EncoderDecoderCache" :
386- att = string_type (
387- obj .self_attention_cache ,
388- with_shape = with_shape ,
389- with_min_max = with_min_max ,
390- with_device = with_device ,
391- limit = limit ,
392- )
393- cross = string_type (
394- obj .cross_attention_cache ,
395- with_shape = with_shape ,
396- with_min_max = with_min_max ,
397- with_device = with_device ,
398- limit = limit ,
399- )
400- return (
401- f"{ obj .__class__ .__name__ } (self_attention_cache={ att } , "
402- f"cross_attention_cache={ cross } )"
403- )
404-
405363 if obj .__class__ .__name__ == "BatchFeature" :
406364 s = string_type (
407365 obj .data ,
@@ -440,19 +398,64 @@ def string_type(
440398 if isinstance (obj , torch .utils ._pytree .TreeSpec ):
441399 return repr (obj ).replace (" " , "" ).replace ("\n " , " " )
442400
443- if ignore :
444- return f"{ obj .__class__ .__name__ } (...)"
401+ # to avoid failures
445402
446- if obj .__class__ in torch .utils ._pytree .SUPPORTED_NODES :
447- args , _spec = torch .utils ._pytree .tree_flatten (obj )
403+ if type (obj ).__name__ == "MambaCache" :
404+ c = string_type (
405+ obj .conv_states ,
406+ with_shape = with_shape ,
407+ with_min_max = with_min_max ,
408+ with_device = with_device ,
409+ limit = limit ,
410+ )
411+ d = string_type (
412+ obj .ssm_states ,
413+ with_shape = with_shape ,
414+ with_min_max = with_min_max ,
415+ with_device = with_device ,
416+ limit = limit ,
417+ )
418+ return f"MambaCache(conv_states={ c } , ssm_states={ d } )"
419+
420+ if obj .__class__ .__name__ == "DynamicCache" :
421+ kc = string_type (
422+ obj .key_cache ,
423+ with_shape = with_shape ,
424+ with_min_max = with_min_max ,
425+ with_device = with_device ,
426+ limit = limit ,
427+ )
428+ vc = string_type (
429+ obj .value_cache ,
430+ with_shape = with_shape ,
431+ with_min_max = with_min_max ,
432+ with_device = with_device ,
433+ limit = limit ,
434+ )
435+ return f"{ obj .__class__ .__name__ } (key_cache={ kc } , value_cache={ vc } )"
436+
437+ if obj .__class__ .__name__ == "EncoderDecoderCache" :
448438 att = string_type (
449- args ,
439+ obj . self_attention_cache ,
450440 with_shape = with_shape ,
451441 with_min_max = with_min_max ,
452442 with_device = with_device ,
453443 limit = limit ,
454444 )
455- return f"{ obj .__class__ .__name__ } ({ att } )"
445+ cross = string_type (
446+ obj .cross_attention_cache ,
447+ with_shape = with_shape ,
448+ with_min_max = with_min_max ,
449+ with_device = with_device ,
450+ limit = limit ,
451+ )
452+ return (
453+ f"{ obj .__class__ .__name__ } (self_attention_cache={ att } , "
454+ f"cross_attention_cache={ cross } )"
455+ )
456+
457+ if ignore :
458+ return f"{ obj .__class__ .__name__ } (...)"
456459
457460 raise AssertionError (f"Unsupported type { type (obj ).__name__ !r} - { type (obj )} " )
458461
0 commit comments