@@ -100,37 +100,52 @@ public struct Gemma3Configuration: Decodable, Sendable {
         // Gemma 3 models have consistent head_dim of 256
         headDim = getOptionalValue(.headDim, type: Int.self) ?? 256
 
-        // num_attention_heads might not be in VLM configs - calculate from hidden_size/head_dim
-        // Gemma 3 sizes: 1B=4 heads, 4B=10 heads, 12B=16 heads, 27B=32 heads
+        // num_attention_heads and num_key_value_heads depend on model size
+        // VLM configs often don't include these, so we infer from hidden_size
+        // Known Gemma 3 configurations:
+        //   1B:  hidden=1152, heads=4,  kv_heads=1
+        //   4B:  hidden=2560, heads=8,  kv_heads=4
+        //   12B: hidden=3840, heads=16, kv_heads=8
+        //   27B: hidden=5120, heads=32, kv_heads=16
         if let heads = getOptionalValue(.numAttentionHeads, type: Int.self) {
             numAttentionHeads = heads
         } else {
-            numAttentionHeads = hiddenSize / headDim
+            // Infer from hidden_size for VLM configs
+            switch hiddenSize {
+            case 1152: numAttentionHeads = 4
+            case 2560: numAttentionHeads = 8
+            case 3840: numAttentionHeads = 16
+            case 5120: numAttentionHeads = 32
+            default: numAttentionHeads = hiddenSize / headDim // Fallback
+            }
         }
 
-        // num_key_value_heads defaults to 1 for Gemma 3 (extreme GQA)
-        numKeyValueHeads = getOptionalValue(.numKeyValueHeads, type: Int.self) ?? 1
+        if let kvHeads = getOptionalValue(.numKeyValueHeads, type: Int.self) {
+            numKeyValueHeads = kvHeads
+        } else {
+            // Infer from hidden_size for VLM configs
+            switch hiddenSize {
+            case 1152: numKeyValueHeads = 1
+            case 2560: numKeyValueHeads = 4
+            case 3840: numKeyValueHeads = 8
+            case 5120: numKeyValueHeads = 16
+            default: numKeyValueHeads = max(1, numAttentionHeads / 4) // Fallback
+            }
+        }
 
-        // vocab_size for Gemma 3 is 262144
-        vocabSize = getOptionalValue(.vocabSize, type: Int.self) ?? 262144
+        // vocab_size for Gemma 3 is 262208 (includes special tokens for VLM)
+        vocabSize = getOptionalValue(.vocabSize, type: Int.self) ?? 262208
 
         rmsNormEps = getOptionalValue(.rmsNormEps, type: Float.self) ?? 1e-6
+
+        // Gemma 3 RoPE theta defaults from HuggingFace transformers:
+        //   - Global attention:        1,000,000
+        //   - Sliding/local attention: 10,000
+        ropeTheta = getOptionalValue(.ropeTheta, type: Float.self) ?? 1000000.0
         ropeLocalTheta = getOptionalValue(.ropeLocalTheta, type: Float.self) ?? 10000.0
 
-        // Check for rope_scaling first to determine appropriate default
+        // Check for rope_scaling (used in 4B+ models for extended context)
         ropeScaling = getOptionalValue(.ropeScaling, type: [String: StringOrNumber].self)
-
-        // rope_theta: 1B uses 1000000, 4B+ with linear scaling should use 10000 as base
-        // (the linear scaling extends the effective context)
-        if let specifiedTheta = getOptionalValue(.ropeTheta, type: Float.self) {
-            ropeTheta = specifiedTheta
-        } else if ropeScaling != nil {
-            // Models with rope_scaling typically use lower base theta
-            ropeTheta = 10000.0
-        } else {
-            // Gemma 3 1B default
-            ropeTheta = 1000000.0
-        }
         maxPositionEmbeddings = getOptionalValue(.maxPositionEmbeddings, type: Int.self) ?? 32768
         slidingWindow = getOptionalValue(.slidingWindow, type: Int.self) ?? 512
         // Default to 0 (no pattern = all layers same) for models like 4B that don't specify it
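
Read together, the two fallback branches above form a single lookup from hidden_size to (heads, kv_heads). Below is a minimal standalone sketch of that table, assuming only the sizes listed in the comments; the inferredHeadCounts helper is illustrative and not part of this PR.

// Hypothetical sketch: map a known Gemma 3 hidden_size to (attention heads, KV heads),
// mirroring the two switch statements in the decoder above. Unknown sizes fall back to
// hiddenSize / headDim for heads and a 4:1 GQA ratio for KV heads.
func inferredHeadCounts(hiddenSize: Int, headDim: Int = 256) -> (heads: Int, kvHeads: Int) {
    switch hiddenSize {
    case 1152: return (4, 1)    // 1B
    case 2560: return (8, 4)    // 4B
    case 3840: return (16, 8)   // 12B
    case 5120: return (32, 16)  // 27B
    default:
        let heads = hiddenSize / headDim
        return (heads, max(1, heads / 4))
    }
}

// Example: a 4B VLM config that omits both fields
// inferredHeadCounts(hiddenSize: 2560)  // -> (heads: 8, kvHeads: 4)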
@@ -355,18 +370,39 @@ class Gemma3ModelInner: Module {
         let scale = MLXArray(sqrt(Float(hiddenSize)))
         hiddenStates = hiddenStates * scale.asType(hiddenStates.dtype)
 
-        // Create masks for global and sliding window attention
-        // Find a global layer to get its cache offset for the full mask
-        let globalLayerIdx = slidingWindowPattern > 0 ? slidingWindowPattern - 1 : 0
-        let globalCache = globalLayerIdx < cache.count ? cache[globalLayerIdx] : nil
-        let fullMask = createAttentionMask(h: hiddenStates, cache: globalCache, windowSize: nil)
-
-        let slidingCache = cache.first ?? nil
-        let slidingMask = createAttentionMask(h: hiddenStates, cache: slidingCache, windowSize: slidingWindow)
+        // Determine mask behavior:
+        // - When slidingWindowPattern > 0: alternating global/sliding (1B style)
+        // - When slidingWindowPattern == 0: all layers use the same mask type
+        //   - If slidingWindow > 0: all use sliding window (4B style)
+        //   - If slidingWindow == 0: all use full attention
+        let useUniformSliding = slidingWindowPattern == 0 && slidingWindow > 0
+
+        // Create masks
+        let firstCache = cache.first ?? nil
+        let slidingMask = createAttentionMask(h: hiddenStates, cache: firstCache, windowSize: slidingWindow)
+
+        // Only create the full mask if we have an alternating pattern
+        let fullMask: MLXFast.ScaledDotProductAttentionMaskMode
+        if slidingWindowPattern > 0 {
+            let globalLayerIdx = slidingWindowPattern - 1
+            let globalCache = globalLayerIdx < cache.count ? cache[globalLayerIdx] : nil
+            fullMask = createAttentionMask(h: hiddenStates, cache: globalCache, windowSize: nil)
+        } else {
+            fullMask = slidingMask // Not used when uniform sliding
+        }
 
         for i in 0 ..< layers.count {
-            let isGlobal = layers[i].isGlobal
-            let mask = isGlobal ? fullMask : slidingMask
+            let mask: MLXFast.ScaledDotProductAttentionMaskMode
+            if useUniformSliding {
+                // 4B style: all layers use sliding window
+                mask = slidingMask
+            } else if layers[i].isGlobal {
+                // 1B style: global layers use full attention
+                mask = fullMask
+            } else {
+                // 1B style: non-global layers use sliding window
+                mask = slidingMask
+            }
             hiddenStates = layers[i](hiddenStates, mask: mask, cache: &cache[i])
         }
 
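
The per-layer branching in the loop above can also be read as a small pure function of the three inputs that matter: the sliding-window pattern, the window size, and whether the layer is marked global. A hedged sketch of that decision follows; MaskKind and maskKind(...) are illustrative names, not types from this PR or from MLX.

// Hypothetical sketch of the per-layer mask choice, mirroring the if/else chain
// inside the layer loop. Callers would pass layers[i].isGlobal for isGlobalLayer.
enum MaskKind { case full, sliding }

func maskKind(isGlobalLayer: Bool, slidingWindowPattern: Int, slidingWindow: Int) -> MaskKind {
    // Uniform sliding (4B style): no alternating pattern, but a nonzero window.
    if slidingWindowPattern == 0 && slidingWindow > 0 { return .sliding }
    // Alternating pattern (1B style): global layers get the full mask, the rest slide.
    return isGlobalLayer ? .full : .sliding
}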