Commit a4bc800

feat: add phi-4 to recommended models
Phi-4 uses the same phi3 architecture and works out of the box. It is larger (14B params) and therefore slower (~3 tok/s vs ~31 tok/s for Phi-3), but it produces better-quality responses.
1 parent 866d641 · commit a4bc800

3 files changed: 45 additions, 43 deletions

packages/node-mlx/src/index.ts

Lines changed: 4 additions & 2 deletions
@@ -161,8 +161,10 @@ export const RECOMMENDED_MODELS = {
   "qwen-2.5-1.5b": "Qwen/Qwen2.5-1.5B-Instruct",
   "qwen-2.5-3b": "Qwen/Qwen2.5-3B-Instruct",

-  // Phi 3 (Microsoft) - Working with fused QKV and RoPE
-  phi: "microsoft/Phi-3-mini-4k-instruct",
+  // Phi (Microsoft) - Working with fused QKV and RoPE
+  phi: "microsoft/phi-4", // Default to latest
+  phi4: "microsoft/phi-4",
+  "phi-4": "microsoft/phi-4",
   phi3: "microsoft/Phi-3-mini-4k-instruct",
   "phi-3": "microsoft/Phi-3-mini-4k-instruct",
   "phi-3-mini": "microsoft/Phi-3-mini-4k-instruct",

packages/swift/Sources/NodeMLXCore/Models/Gemma3n.swift

Lines changed: 40 additions & 40 deletions
Every change in this file is whitespace-only: trailing whitespace is stripped from otherwise blank lines (40 deletions paired with 40 identical additions), with no functional change. The affected hunks:

@@ -30,25 +30,25 @@ public struct Gemma3nConfiguration: Decodable, Sendable {
@@ -73,10 +73,10 @@ public struct Gemma3nConfiguration: Decodable, Sendable {
@@ -147,28 +147,28 @@ public struct Gemma3nConfiguration: Decodable, Sendable {
@@ -178,20 +178,20 @@ class Gemma3nTextDecoderLayer: Module {
@@ -204,13 +204,13 @@ class Gemma3nTextDecoderLayer: Module {
@@ -221,33 +221,33 @@ class Gemma3nTextDecoderLayer: Module {
@@ -256,26 +256,26 @@ class Gemma3nTextModelInner: Module {
@@ -285,29 +285,29 @@ class Gemma3nTextModelInner: Module {
@@ -321,28 +321,28 @@ public class Gemma3nModel: Module, LLMModel {

packages/swift/Sources/NodeMLXCore/Models/Phi3.swift

Lines changed: 1 addition & 1 deletion
@@ -353,4 +353,4 @@ public class Phi3Model: Module, LLMModel {
         // Override in subclass if weight key mapping needed
         return weights
     }
-}
+}

(Both the removed and added lines are the file's closing brace; the change is whitespace-only, most likely adding a missing end-of-file newline.)

0 commit comments