@@ -9,6 +9,7 @@ def get_phi2(
     sequence_length: int = 30,
     sequence_length2: int = 3,
     dynamic_rope: bool = False,
+    use_dim_not_dynamic: bool = False,
     **kwargs,
 ) -> Dict[str, Any]:
     """
@@ -18,6 +19,8 @@ def get_phi2(
     :param sequence_length: sequence length
     :param sequence_length2: new sequence length
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+    :param use_dim_not_dynamic: uses ``torch.export.Dim`` and not a string for the batch size,
+        the sequence length and the cache length
     :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
     :return: dictionary
 
@@ -62,7 +65,7 @@ def get_phi2(
     n_layers = config["num_hidden_layers"]
     num_key_value_heads = config["num_key_value_heads"]
 
-    if batch_size == 1:
+    if use_dim_not_dynamic:
         batch = torch.export.Dim("batch", min=1, max=1024)
         seq_length = torch.export.Dim("seq_length", min=1, max=4096)
         cache_length = torch.export.Dim("cache_length", min=1, max=4096)
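For context, a minimal sketch of what the flag amounts to when the returned dynamic shapes reach torch.export.export. The Toy module and the input_ids argument name are made up for illustration and are not part of this patch; only the Dim construction mirrors the lines above.

import torch


class Toy(torch.nn.Module):
    # Stand-in model; the real get_phi2 builds a Phi-2 configuration and inputs.
    def forward(self, input_ids):
        return input_ids.float().mean(dim=-1)


model = Toy()
input_ids = torch.zeros((2, 30), dtype=torch.int64)

# use_dim_not_dynamic=True: explicit torch.export.Dim objects with bounds,
# exactly as created in the branch above.
batch = torch.export.Dim("batch", min=1, max=1024)
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
ep = torch.export.export(
    model,
    (input_ids,),
    dynamic_shapes={"input_ids": {0: batch, 1: seq_length}},
)
print(ep)

# use_dim_not_dynamic=False: the returned dictionary presumably keeps plain
# strings instead, e.g. {"input_ids": {0: "batch", 1: "seq_length"}}, leaving
# their interpretation (or conversion to Dim objects) to the caller.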