File tree Expand file tree Collapse file tree 3 files changed +5
-5
lines changed
Expand file tree Collapse file tree 3 files changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -1014,7 +1014,7 @@ def assertEqualArrayAny(
10141014 msg_ = "\n " .join (excs )
10151015 msg = f"{ msg } \n { msg_ } " if msg else msg_
10161016 raise AssertionError (f"Found { len (excs )} discrepancies\n { msg } " )
1017- elif expected .__class__ .__name__ == "DynamicCache" :
1017+ elif expected .__class__ .__name__ in ( "DynamicCache" , "StaticCache" ) :
10181018 atts = {"key_cache" , "value_cache" }
10191019 self .assertEqualArrayAny (
10201020 {k : expected .__dict__ .get (k , None ) for k in atts },
Original file line number Diff line number Diff line change @@ -174,7 +174,7 @@ def get_inputs(
174174 shapes = {
175175 "input_ids" : {0 : batch , 1 : seq_length },
176176 "attention_mask" : {0 : batch , 2 : "seq" },
177- "cache_position" : {1 : "seq" },
177+ "cache_position" : {0 : "seq" },
178178 "past_key_values" : [
179179 [{0 : batch , 2 : cache_length } for _ in range (num_hidden_layers )],
180180 [{0 : batch , 2 : cache_length } for _ in range (num_hidden_layers )],
Original file line number Diff line number Diff line change @@ -57,13 +57,13 @@ def get_tiny_llm(
5757 res = get_inputs (
5858 model ,
5959 conf ,
60- dummy_max_token_id = config ["vocab_size" ],
61- num_hidden_layers = config ["num_hidden_layers" ],
60+ dummy_max_token_id = config ["vocab_size" ], # type: ignore[arg-type]
61+ num_hidden_layers = config ["num_hidden_layers" ], # type: ignore[arg-type]
6262 batch_size = batch_size ,
6363 sequence_length = sequence_length ,
6464 sequence_length2 = sequence_length2 ,
6565 dynamic_rope = dynamic_rope ,
66- num_key_value_heads = config ["num_key_value_heads" ],
66+ num_key_value_heads = config ["num_key_value_heads" ], # type: ignore[arg-type]
6767 cls_cache = "StaticCache" if use_static_cache else "DynamicCache" ,
6868 )
6969
You can’t perform that action at this time.
0 commit comments