@@ -85,6 +85,7 @@ class ModelArgs:
     n_kv_heads: Optional[int] = None
     vocab_size: int = -1  # defined later by tokenizer
     hidden_dim: Optional[int] = None
+    head_dim: Optional[int] = None  # Optional customized head_dim
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
@@ -142,6 +143,9 @@ def __post_init__(self):
             hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_heads
+
 
 class KVCache(nn.Module):
     def __init__(
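The fallback above keeps existing configs working: when head_dim is not set, it is derived exactly as before. A minimal standalone sketch of the defaulting logic (ArgsSketch is a hypothetical stand-in mirroring only the fields involved, not the real ModelArgs):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ArgsSketch:
    dim: int = 4096
    n_heads: int = 32
    head_dim: Optional[int] = None  # optional override, as in this patch

    def __post_init__(self):
        # Fall back to the classic derivation when no override is given.
        if self.head_dim is None:
            self.head_dim = self.dim // self.n_heads

assert ArgsSketch().head_dim == 128            # default: 4096 // 32
assert ArgsSketch(head_dim=64).head_dim == 64  # explicit override wins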
@@ -272,7 +276,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // self.n_heads
+        self.head_dim = args.head_dim
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
@@ -304,7 +308,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         )
         self.SDPA = SDPA(
             kv_cache=self.kv_cache,
-            dim=self.dim,
+            dim=self.n_local_heads * self.head_dim,
             head_dim=self.head_dim,
             n_rep=self.n_rep,
             max_seq_len=self.max_seq_len,
@@ -425,7 +429,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.head_dim
         self.attention = Attention(args, layer_id)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
@@ -472,7 +476,7 @@ def __init__(self, params: ModelArgs):
             precompute_freqs_cis, use_scaled=params.use_scaled_rope
         )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
+            params.head_dim,
             (
                 params.max_seq_len  # Normal llama2.
                 if params.ffn_dim_multiplier is None
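The remaining hunks follow from the same decoupling: once head_dim can differ from dim // n_heads, the identity dim == n_heads * head_dim no longer holds, so the attention output passed to SDPA is n_local_heads * head_dim wide rather than dim, and the RoPE table built by precompute_freqs_cis must be sized by head_dim. A quick arithmetic sketch with illustrative numbers (not taken from any real config):

dim, n_heads, model_parallel_size = 4096, 32, 1
n_local_heads = n_heads // model_parallel_size

default_head_dim = dim // n_heads               # 128, the old implicit value
assert n_local_heads * default_head_dim == dim  # widths coincide, so dim=self.dim used to be safe

custom_head_dim = 64                            # override enabled by this patch
assert n_local_heads * custom_head_dim == 2048  # != dim, hence dim=self.n_local_heads * self.head_dim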