This repository was archived by the owner on Jan 16, 2026. It is now read-only.

Commit 6f19a8d

feat(gemma3): improve 4B config inference from weight shapes
- Infer num_attention_heads/num_key_value_heads from hidden_size
  - 1B: 4 heads, 1 kv_head
  - 4B: 8 heads, 4 kv_heads
  - 12B: 16 heads, 8 kv_heads
  - 27B: 32 heads, 16 kv_heads
- Fix vocab_size default to 262208 (includes VLM special tokens)
- Improve mask selection for uniform sliding window (4B style)
- Fix rope_theta default to 1000000 (same for all Gemma 3 models)

Note: 4B still produces gibberish - needs further investigation.
1 parent 52e4979 commit 6f19a8d
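A note on the head-count table: for Gemma 3, the attention width (heads * head_dim) does not equal hidden_size, so the old hiddenSize / headDim fallback over-counts; for the 4B model it gives 2560 / 256 = 10, while the model actually uses 8 query heads. A minimal standalone sketch of the lookup this commit hard-codes (gemma3HeadCounts is an illustrative name, not part of this codebase):

    // Hypothetical helper mirroring the switch added in Gemma3Configuration below.
    func gemma3HeadCounts(hiddenSize: Int, headDim: Int = 256) -> (heads: Int, kvHeads: Int) {
        switch hiddenSize {
        case 1152: return (4, 1)    // 1B
        case 2560: return (8, 4)    // 4B
        case 3840: return (16, 8)   // 12B
        case 5120: return (32, 16)  // 27B
        default:
            // Fallback mirrors the old code path; only correct when heads * headDim == hiddenSize
            let heads = hiddenSize / headDim
            return (heads, max(1, heads / 4))
        }
    }

    let (heads, kvHeads) = gemma3HeadCounts(hiddenSize: 2560)
    print(heads, kvHeads)  // 8 4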

1 file changed (+66, -30)

packages/swift/Sources/NodeMLXCore/Models/Gemma3.swift

Lines changed: 66 additions & 30 deletions
@@ -100,37 +100,52 @@ public struct Gemma3Configuration: Decodable, Sendable {
  // Gemma 3 models have consistent head_dim of 256
  headDim = getOptionalValue(.headDim, type: Int.self) ?? 256

- // num_attention_heads might not be in VLM configs - calculate from hidden_size/head_dim
- // Gemma 3 sizes: 1B=4 heads, 4B=10 heads, 12B=16 heads, 27B=32 heads
+ // num_attention_heads and num_key_value_heads depend on model size
+ // VLM configs often don't include these, so we infer from hidden_size
+ // Known Gemma 3 configurations:
+ //   1B:  hidden=1152, heads=4,  kv_heads=1
+ //   4B:  hidden=2560, heads=8,  kv_heads=4
+ //   12B: hidden=3840, heads=16, kv_heads=8
+ //   27B: hidden=5120, heads=32, kv_heads=16
  if let heads = getOptionalValue(.numAttentionHeads, type: Int.self) {
      numAttentionHeads = heads
  } else {
-     numAttentionHeads = hiddenSize / headDim
+     // Infer from hidden_size for VLM configs
+     switch hiddenSize {
+     case 1152: numAttentionHeads = 4
+     case 2560: numAttentionHeads = 8
+     case 3840: numAttentionHeads = 16
+     case 5120: numAttentionHeads = 32
+     default: numAttentionHeads = hiddenSize / headDim // Fallback
+     }
  }

- // num_key_value_heads defaults to 1 for Gemma 3 (extreme GQA)
- numKeyValueHeads = getOptionalValue(.numKeyValueHeads, type: Int.self) ?? 1
+ if let kvHeads = getOptionalValue(.numKeyValueHeads, type: Int.self) {
+     numKeyValueHeads = kvHeads
+ } else {
+     // Infer from hidden_size for VLM configs
+     switch hiddenSize {
+     case 1152: numKeyValueHeads = 1
+     case 2560: numKeyValueHeads = 4
+     case 3840: numKeyValueHeads = 8
+     case 5120: numKeyValueHeads = 16
+     default: numKeyValueHeads = max(1, numAttentionHeads / 4) // Fallback
+     }
+ }

- // vocab_size for Gemma 3 is 262144
- vocabSize = getOptionalValue(.vocabSize, type: Int.self) ?? 262144
+ // vocab_size for Gemma 3 is 262208 (includes special tokens for VLM)
+ vocabSize = getOptionalValue(.vocabSize, type: Int.self) ?? 262208

  rmsNormEps = getOptionalValue(.rmsNormEps, type: Float.self) ?? 1e-6
+
+ // Gemma 3 RoPE theta defaults from HuggingFace transformers:
+ // - Global attention: 1,000,000
+ // - Sliding/local attention: 10,000
+ ropeTheta = getOptionalValue(.ropeTheta, type: Float.self) ?? 1000000.0
  ropeLocalTheta = getOptionalValue(.ropeLocalTheta, type: Float.self) ?? 10000.0

- // Check for rope_scaling first to determine appropriate default
+ // Check for rope_scaling (used in 4B+ models for extended context)
  ropeScaling = getOptionalValue(.ropeScaling, type: [String: StringOrNumber].self)
-
- // rope_theta: 1B uses 1000000, 4B+ with linear scaling should use 10000 as base
- // (the linear scaling extends the effective context)
- if let specifiedTheta = getOptionalValue(.ropeTheta, type: Float.self) {
-     ropeTheta = specifiedTheta
- } else if ropeScaling != nil {
-     // Models with rope_scaling typically use lower base theta
-     ropeTheta = 10000.0
- } else {
-     // Gemma 3 1B default
-     ropeTheta = 1000000.0
- }
  maxPositionEmbeddings = getOptionalValue(.maxPositionEmbeddings, type: Int.self) ?? 32768
  slidingWindow = getOptionalValue(.slidingWindow, type: Int.self) ?? 512
  // Default to 0 (no pattern = all layers same) for models like 4B that don't specify it
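The hunk above keeps two separate bases: rope_theta defaults to 1,000,000 for global-attention layers and rope_local_theta to 10,000 for sliding-window layers. A minimal sketch of what that difference means, assuming the standard RoPE inverse-frequency formula base^(-2i/head_dim); ropeInverseFrequencies is an illustrative helper, not code from this repository:

    import Foundation

    // Inverse frequencies for rotary embeddings: a larger base produces longer
    // wavelengths, i.e. slower-rotating dimensions, which suits the long-context
    // global layers; local sliding-window layers keep the classic 10,000 base.
    func ropeInverseFrequencies(base: Double, headDim: Int) -> [Double] {
        (0..<(headDim / 2)).map { i in
            pow(base, -2.0 * Double(i) / Double(headDim))
        }
    }

    let globalFreqs = ropeInverseFrequencies(base: 1_000_000, headDim: 256)
    let localFreqs = ropeInverseFrequencies(base: 10_000, headDim: 256)
    print(globalFreqs.last!, localFreqs.last!)  // the global base rotates far more slowly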
@@ -355,18 +370,39 @@ class Gemma3ModelInner: Module {
  let scale = MLXArray(sqrt(Float(hiddenSize)))
  hiddenStates = hiddenStates * scale.asType(hiddenStates.dtype)

- // Create masks for global and sliding window attention
- // Find a global layer to get its cache offset for the full mask
- let globalLayerIdx = slidingWindowPattern > 0 ? slidingWindowPattern - 1 : 0
- let globalCache = globalLayerIdx < cache.count ? cache[globalLayerIdx] : nil
- let fullMask = createAttentionMask(h: hiddenStates, cache: globalCache, windowSize: nil)
-
- let slidingCache = cache.first ?? nil
- let slidingMask = createAttentionMask(h: hiddenStates, cache: slidingCache, windowSize: slidingWindow)
+ // Determine mask behavior:
+ // - When slidingWindowPattern > 0: alternating global/sliding (1B style)
+ // - When slidingWindowPattern == 0: all layers use same mask type
+ //   - If slidingWindow > 0: all use sliding window (4B style)
+ //   - If slidingWindow == 0: all use full attention
+ let useUniformSliding = slidingWindowPattern == 0 && slidingWindow > 0
+
+ // Create masks
+ let firstCache = cache.first ?? nil
+ let slidingMask = createAttentionMask(h: hiddenStates, cache: firstCache, windowSize: slidingWindow)
+
+ // Only create full mask if we have alternating pattern
+ let fullMask: MLXFast.ScaledDotProductAttentionMaskMode
+ if slidingWindowPattern > 0 {
+     let globalLayerIdx = slidingWindowPattern - 1
+     let globalCache = globalLayerIdx < cache.count ? cache[globalLayerIdx] : nil
+     fullMask = createAttentionMask(h: hiddenStates, cache: globalCache, windowSize: nil)
+ } else {
+     fullMask = slidingMask // Not used when uniform sliding
+ }

  for i in 0..<layers.count {
-     let isGlobal = layers[i].isGlobal
-     let mask = isGlobal ? fullMask : slidingMask
+     let mask: MLXFast.ScaledDotProductAttentionMaskMode
+     if useUniformSliding {
+         // 4B style: all layers use sliding window
+         mask = slidingMask
+     } else if layers[i].isGlobal {
+         // 1B style: global layers use full attention
+         mask = fullMask
+     } else {
+         // 1B style: non-global layers use sliding window
+         mask = slidingMask
+     }
      hiddenStates = layers[i](hiddenStates, mask: mask, cache: &cache[i])
  }
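Separated out from the loop above, the per-layer decision reduces to a small pure function. This is an illustrative sketch, not code from the repository: MaskKind and maskKind are made-up names, and it assumes the common Gemma 3 convention that a layer is global when (index + 1) is a multiple of slidingWindowPattern, whereas the actual code consults layers[i].isGlobal.

    enum MaskKind { case full, sliding }

    func maskKind(layer i: Int, slidingWindowPattern: Int, slidingWindow: Int) -> MaskKind {
        // 4B style: no alternating pattern, but a window is set -> every layer slides
        if slidingWindowPattern == 0 && slidingWindow > 0 { return .sliding }
        // No pattern and no window -> every layer attends globally
        if slidingWindowPattern == 0 { return .full }
        // 1B style: every slidingWindowPattern-th layer is global (assumed convention)
        return (i + 1) % slidingWindowPattern == 0 ? .full : .sliding
    }

    // Example: a 1B-style pattern of 6 gives the full mask on layers 5, 11, ...
    let kinds = (0..<12).map { maskKind(layer: $0, slidingWindowPattern: 6, slidingWindow: 512) }
    print(kinds)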
