@@ -30,25 +30,25 @@ public struct Gemma3nConfiguration: Decodable, Sendable {
    public var queryPreAttnScalar: Int?
    public var layerTypes: [String]?
    public var slidingWindow: Int?
-
+
    // Gemma3n specific
    public var hiddenSizePerLayerInput: Int?
    public var vocabSizePerLayerInput: Int?
    public var altupNumInputs: Int?
    public var altupActiveIdx: Int?
    public var altupCorrectScale: Bool?
    public var laurelRank: Int?
-
+
    public var modelType: String?
-
+
    /// Get intermediate size for a specific layer
    public func intermediateSize(forLayer layer: Int) -> Int {
        if layer < intermediateSizes.count {
            return intermediateSizes[layer]
        }
        return intermediateSizes.first ?? 8192
    }
-
+
    enum CodingKeys: String, CodingKey {
        case textConfig = "text_config"
        case hiddenSize = "hidden_size"
@@ -73,10 +73,10 @@ public struct Gemma3nConfiguration: Decodable, Sendable {
        case laurelRank = "laurel_rank"
        case modelType = "model_type"
    }
-
+
    public init(from decoder: Swift.Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
-
+
        // Try to decode from text_config first (VLM format), then from top level
        if let textContainer = try? container.nestedContainer(keyedBy: CodingKeys.self, forKey: .textConfig) {
            hiddenSize = try textContainer.decode(Int.self, forKey: .hiddenSize)
@@ -147,28 +147,28 @@ public struct Gemma3nConfiguration: Decodable, Sendable {
class Gemma3nTextDecoderLayer: Module {
    let layerIdx: Int
    let attentionType: String
-
+
    @ModuleInfo(key: "self_attn") var selfAttn: Gemma3nTextAttention
    @ModuleInfo(key: "mlp") var mlp: Gemma3nTextMLP
    @ModuleInfo(key: "input_layernorm") var inputLayernorm: Gemma3nRMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayernorm: Gemma3nRMSNorm
    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayernorm: Gemma3nRMSNorm
    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayernorm: Gemma3nRMSNorm
-
+
    init(_ config: Gemma3nConfiguration, layerIdx: Int) {
        self.layerIdx = layerIdx
-
+
        // Determine attention type for this layer
        if let layerTypes = config.layerTypes, layerIdx < layerTypes.count {
            self.attentionType = layerTypes[layerIdx]
        } else {
            self.attentionType = "full_attention"
        }
-
+
        let eps = config.rmsNormEps ?? 1e-6
        let numKVHeads = config.numKeyValueHeads ?? config.numAttentionHeads
        let intermediateSize = config.intermediateSize(forLayer: layerIdx)
-
+
        // Initialize attention
        self._selfAttn.wrappedValue = Gemma3nTextAttention(
            hiddenSize: config.hiddenSize,
@@ -178,20 +178,20 @@ class Gemma3nTextDecoderLayer: Module {
            queryPreAttnScalar: config.queryPreAttnScalar,
            eps: eps
        )
-
+
        // Initialize MLP
        self._mlp.wrappedValue = Gemma3nTextMLP(
            hiddenSize: config.hiddenSize,
            intermediateSize: intermediateSize
        )
-
+
        // Initialize norms
        self._inputLayernorm.wrappedValue = Gemma3nRMSNorm(dimensions: config.hiddenSize, eps: eps)
        self._postAttentionLayernorm.wrappedValue = Gemma3nRMSNorm(dimensions: config.hiddenSize, eps: eps)
        self._preFeedforwardLayernorm.wrappedValue = Gemma3nRMSNorm(dimensions: config.hiddenSize, eps: eps)
        self._postFeedforwardLayernorm.wrappedValue = Gemma3nRMSNorm(dimensions: config.hiddenSize, eps: eps)
    }
-
+
    /// Simplified forward pass (standard transformer without AltUp/Laurel)
    func callAsFunction(
        _ hiddenStates: MLXArray,
@@ -204,13 +204,13 @@ class Gemma3nTextDecoderLayer: Module {
        let attnOut = selfAttn(normed, positionEmbeddings: positionEmbeddings, mask: mask, cache: cache)
        let attnNormed = postAttentionLayernorm(attnOut)
        var h = hiddenStates + attnNormed
-
+
        // 2. Pre-norm + MLP
        let mlpIn = preFeedforwardLayernorm(h)
        let mlpOut = mlp(mlpIn)
        let mlpNormed = postFeedforwardLayernorm(mlpOut)
        h = h + mlpNormed
-
+
        return h
    }
}
@@ -221,33 +221,33 @@ class Gemma3nTextDecoderLayer: Module {
class Gemma3nTextModelInner: Module {
    let numLayers: Int
    let hiddenSize: Int
-
+
    @ModuleInfo(key: "embed_tokens") var embedTokens: Gemma3nTextScaledWordEmbedding
    @ModuleInfo(key: "layers") var layers: [Gemma3nTextDecoderLayer]
    @ModuleInfo(key: "norm") var norm: Gemma3nRMSNorm
    @ModuleInfo(key: "rotary_emb") var rotaryEmb: Gemma3nRotaryEmbedding
-
+
    init(_ config: Gemma3nConfiguration) {
        self.numLayers = config.numHiddenLayers
        self.hiddenSize = config.hiddenSize
-
+
        let eps = config.rmsNormEps ?? 1e-6
-
+
        // Main token embedding
        self._embedTokens.wrappedValue = Gemma3nTextScaledWordEmbedding(
            embeddingCount: config.vocabSize,
            dimensions: config.hiddenSize,
            embedScale: sqrt(Float(config.hiddenSize))
        )
-
+
        // Decoder layers
        self._layers.wrappedValue = (0..<numLayers).map { idx in
            Gemma3nTextDecoderLayer(config, layerIdx: idx)
        }
-
+
        // Final norm
        self._norm.wrappedValue = Gemma3nRMSNorm(dimensions: config.hiddenSize, eps: eps)
-
+
        // Rotary embedding
        self._rotaryEmb.wrappedValue = Gemma3nRotaryEmbedding(
            dim: config.headDim,
@@ -256,26 +256,26 @@ class Gemma3nTextModelInner: Module {
            ropeLocalBaseFreq: config.ropeLocalBaseFreq ?? 10000.0
        )
    }
-
+
    func callAsFunction(_ inputIds: MLXArray, cache: [[KVCache]]? = nil) -> MLXArray {
        // 1. Embed tokens
        var hiddenStates = embedTokens(inputIds)
-
+
        // 2. Compute position embeddings
        let seqLen = inputIds.dim(1)
        let positions = MLXArray(Array(0..<seqLen).map { Int32($0) })
-
+
        // 3. Process through layers
        for (layerIdx, layer) in layers.enumerated() {
            let positionEmbeddings = rotaryEmb(positions, layerType: layer.attentionType)
            let layerCache = cache?[layerIdx].first
-
+
            hiddenStates = layer(hiddenStates, positionEmbeddings: positionEmbeddings, mask: nil, cache: layerCache)
        }
-
+
        // 4. Final norm
        hiddenStates = norm(hiddenStates)
-
+
        return hiddenStates
    }
}
@@ -285,29 +285,29 @@ class Gemma3nTextModelInner: Module {
public class Gemma3nModel: Module, LLMModel {
    public let vocabularySize: Int
    public let numLayers: Int
-
+
    @ModuleInfo(key: "model") var model: Gemma3nTextModelInner
    @ModuleInfo(key: "lm_head") var lmHead: Linear
    private let config: Gemma3nConfiguration
-
+
    public init(_ config: Gemma3nConfiguration) {
        self.config = config
        self.vocabularySize = config.vocabSize
        self.numLayers = config.numHiddenLayers
-
+
        self._model.wrappedValue = Gemma3nTextModelInner(config)
        self._lmHead.wrappedValue = Linear(config.hiddenSize, config.vocabSize, bias: false)
    }
-
+
    public func callAsFunction(_ inputIds: MLXArray) -> MLXArray {
        let h = model(inputIds, cache: nil)
        return lmHead(h)
    }
-
+
    /// Sanitize weight keys from HuggingFace format
    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        var result: [String: MLXArray] = [:]
-
+
        // Keys to skip (complex Gemma3n-specific modules that we don't use in the simplified forward pass)
        let skipPatterns = [
            "altup",
@@ -321,28 +321,28 @@ public class Gemma3nModel: Module, LLMModel {
            "altup_projections",
            "altup_unembed_projections",
        ]
-
+
        for (key, value) in weights {
            var newKey = key
-
+
            // model.language_model.X -> model.X
            if newKey.hasPrefix("model.language_model.") {
                newKey = "model." + String(newKey.dropFirst("model.language_model.".count))
            } else if newKey.hasPrefix("language_model.") {
                newKey = "model." + String(newKey.dropFirst("language_model.".count))
            }
-
+
            // Skip complex modules we don't use
            let shouldSkip = skipPatterns.contains { newKey.contains($0) }
            if shouldSkip {
                continue
            }
-
+
            // Remap embed_tokens.X -> embed_tokens.inner.X (for our wrapper structure)
            if newKey.contains("embed_tokens.") && !newKey.contains("embed_tokens.inner.") {
                newKey = newKey.replacingOccurrences(of: "embed_tokens.", with: "embed_tokens.inner.")
            }
-
+
            result[newKey] = value
        }
        return result
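
Note (reviewer sketch, not part of the commit): a minimal example of how the decoded configuration and the simplified model above might be driven end to end. Only the types and call signatures visible in this diff are assumed; the config path and token IDs are placeholders, and weight loading via sanitize(weights:) is elided.

import Foundation
import MLX

do {
    // Decode a config.json (placeholder path) into the Gemma3nConfiguration above.
    // The custom init(from:) handles both the VLM text_config nesting and the flat layout.
    let configURL = URL(fileURLWithPath: "/path/to/config.json")
    let config = try JSONDecoder().decode(Gemma3nConfiguration.self, from: Data(contentsOf: configURL))

    // Build the simplified text model.
    // (Weight loading elided: pass HuggingFace weights through sanitize(weights:) and apply them here.)
    let model = Gemma3nModel(config)

    // Toy prompt: placeholder token IDs with shape [batch = 1, seqLen = 4].
    let inputIds = MLXArray([2, 106, 1645, 108].map { Int32($0) }).reshaped([1, 4])

    // One forward pass through embed -> decoder layers -> final norm -> lm_head.
    let logits = model(inputIds)  // shape [1, 4, vocabSize]
    print(logits.shape)
} catch {
    print("Failed to load configuration: \(error)")
}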