Skip to content
This repository was archived by the owner on Jan 16, 2026. It is now read-only.

Commit 934822b

Browse files
committed
feat(gemma3): add linear RoPE scaling for 4B+ models
- Add RoPEProvider protocol for polymorphic RoPE usage - Parse rope_scaling from config for linear scaling - Fix slidingWindowPattern default to 0 for VLM configs - Fix cache type selection for models without pattern - Fix CLI argument parsing (model then prompt) Note: 4B model still produces gibberish, needs further debugging.
1 parent 5df1701 commit 934822b

File tree

4 files changed

+68
-19
lines changed

4 files changed

+68
-19
lines changed

packages/node-mlx/src/cli.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,13 @@ function parseArgs(): {
434434
} else if (arg === "--tokens" || arg === "-n") {
435435
options.maxTokens = parseInt(args[++i] || "512", 10)
436436
} else if (!arg.startsWith("-")) {
437-
prompt = arg
438-
command = "oneshot"
437+
// First positional arg is model, second is prompt
438+
if (model === "qwen") {
439+
model = arg
440+
} else if (prompt === null) {
441+
prompt = arg
442+
command = "oneshot"
443+
}
439444
}
440445
}
441446

packages/node-mlx/src/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,10 @@ export const RECOMMENDED_MODELS = {
180180
gemma: "mlx-community/gemma-3-1b-it-4bit",
181181
"gemma-3": "mlx-community/gemma-3-1b-it-4bit",
182182
"gemma-3-1b": "mlx-community/gemma-3-1b-it-4bit",
183-
"gemma-3-1b-bf16": "mlx-community/gemma-3-1b-it-bf16"
183+
"gemma-3-1b-bf16": "mlx-community/gemma-3-1b-it-bf16",
184+
"gemma-3-4b": "mlx-community/gemma-3-4b-it-4bit"
184185

185186
// TODO: These models need fixes:
186-
// - Gemma 3 4B+: Linear RoPE scaling (factor 8.0)
187187
// - Gemma3n: Complex AltUp/Laurel architecture (VLM with audio)
188188
// - Mistral: GQA head count compatibility
189189
} as const

packages/swift/Sources/NodeMLXCore/Models/Gemma3.swift

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,19 @@ public struct Gemma3Configuration: Decodable, Sendable {
2828
public var maxPositionEmbeddings: Int
2929
public var slidingWindow: Int // Window size for sliding attention layers
3030
public var slidingWindowPattern: Int // Every Nth layer is global (full) attention
31+
public var ropeScaling: [String: StringOrNumber]? // For linear scaling in 4B+ models
3132

3233
public var modelType: String?
3334

3435
/// Check if a layer is a global attention layer
3536
public func isGlobalLayer(_ layerIdx: Int) -> Bool {
3637
// Every Nth layer is global, where N = slidingWindowPattern
3738
// Pattern starts from layer 0, so layer 5 (6th) is global when pattern=6
38-
return slidingWindowPattern > 0 && (layerIdx + 1) % slidingWindowPattern == 0
39+
// When slidingWindowPattern is 0, treat all layers as global (applies to 4B+ models)
40+
if slidingWindowPattern == 0 {
41+
return true
42+
}
43+
return (layerIdx + 1) % slidingWindowPattern == 0
3944
}
4045

4146
enum CodingKeys: String, CodingKey {
@@ -53,6 +58,7 @@ public struct Gemma3Configuration: Decodable, Sendable {
5358
case maxPositionEmbeddings = "max_position_embeddings"
5459
case slidingWindow = "sliding_window"
5560
case slidingWindowPattern = "sliding_window_pattern"
61+
case ropeScaling = "rope_scaling"
5662
case modelType = "model_type"
5763
}
5864

@@ -113,7 +119,9 @@ public struct Gemma3Configuration: Decodable, Sendable {
113119
ropeLocalTheta = getOptionalValue(.ropeLocalTheta, type: Float.self) ?? 10000.0
114120
maxPositionEmbeddings = getOptionalValue(.maxPositionEmbeddings, type: Int.self) ?? 32768
115121
slidingWindow = getOptionalValue(.slidingWindow, type: Int.self) ?? 512
116-
slidingWindowPattern = getOptionalValue(.slidingWindowPattern, type: Int.self) ?? 6
122+
// Default to 0 (no pattern = all layers same) for models like 4B that don't specify it
123+
slidingWindowPattern = getOptionalValue(.slidingWindowPattern, type: Int.self) ?? 0
124+
ropeScaling = getOptionalValue(.ropeScaling, type: [String: StringOrNumber].self)
117125
modelType = getOptionalValue(.modelType, type: String.self)
118126
}
119127
}
@@ -152,7 +160,7 @@ class Gemma3Attention: Module {
152160
let numKVHeads: Int
153161
let headDim: Int
154162
let scale: Float
155-
let rope: RoPE
163+
let rope: any RoPEProvider // Can be RoPE or other RoPE variants
156164
let isGlobal: Bool
157165
let slidingWindow: Int?
158166

@@ -177,8 +185,17 @@ class Gemma3Attention: Module {
177185
self._kNorm.wrappedValue = Gemma3RMSNorm(dimensions: headDim, eps: config.rmsNormEps)
178186

179187
// Different RoPE theta for sliding vs global layers
188+
// For global layers with rope_scaling, use the scaling config
180189
let ropeBase = isGlobal ? config.ropeTheta : config.ropeLocalTheta
181-
self.rope = RoPE(dimensions: headDim, traditional: false, base: ropeBase)
190+
191+
// Use initializeRope to handle linear scaling for 4B+ models
192+
self.rope = initializeRope(
193+
dims: headDim,
194+
base: ropeBase,
195+
traditional: false,
196+
scalingConfig: isGlobal ? config.ropeScaling : nil, // Only global layers use scaling
197+
maxPositionEmbeddings: config.maxPositionEmbeddings
198+
)
182199
}
183200

184201
func callAsFunction(
@@ -204,8 +221,8 @@ class Gemma3Attention: Module {
204221

205222
// Apply RoPE with cache offset
206223
let offset = cache?.offset ?? 0
207-
queries = rope(queries, offset: offset)
208-
keys = rope(keys, offset: offset)
224+
queries = rope.apply(queries, offset: offset)
225+
keys = rope.apply(keys, offset: offset)
209226

210227
// Update cache
211228
if let c = cache {
@@ -397,12 +414,14 @@ public class Gemma3Model: Module, LLMModel {
397414
/// Create a new KV cache with appropriate cache types per layer
398415
public func newCache() -> [KVCache] {
399416
return (0..<numLayers).map { i in
400-
if config.isGlobalLayer(i) {
401-
// Global layers use standard cache
402-
return KVCacheSimple()
403-
} else {
404-
// Sliding window layers use rotating cache
417+
// When slidingWindowPattern is 0, all layers use sliding window
418+
// When slidingWindowPattern > 0, only non-global layers use sliding window
419+
let useSlidingWindow = config.slidingWindowPattern == 0 || !config.isGlobalLayer(i)
420+
421+
if useSlidingWindow && config.slidingWindow > 0 {
405422
return RotatingKVCache(maxSize: config.slidingWindow, keep: 0)
423+
} else {
424+
return KVCacheSimple()
406425
}
407426
}
408427
}

packages/swift/Sources/NodeMLXCore/RoPEUtils.swift

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,22 @@ import MLX
66
import MLXFast
77
import MLXNN
88

9+
// MARK: - RoPE Protocol
10+
11+
/// Protocol for all RoPE variants to enable polymorphic usage
12+
public protocol RoPEProvider {
13+
func apply(_ x: MLXArray, offset: Int) -> MLXArray
14+
}
15+
16+
extension RoPE: RoPEProvider {
17+
public func apply(_ x: MLXArray, offset: Int) -> MLXArray {
18+
callAsFunction(x, offset: offset)
19+
}
20+
}
21+
922
// MARK: - Llama3RoPE
1023

11-
public class Llama3RoPE: Module {
24+
public class Llama3RoPE: Module, RoPEProvider {
1225
let dims: Int
1326
let maxPositionEmbeddings: Int
1427
let traditional: Bool
@@ -71,11 +84,15 @@ public class Llama3RoPE: Module {
7184
freqs: freqs
7285
)
7386
}
87+
88+
public func apply(_ x: MLXArray, offset: Int) -> MLXArray {
89+
callAsFunction(x, offset: offset)
90+
}
7491
}
7592

7693
// MARK: - YarnRoPE
7794

78-
public class YarnRoPE: Module {
95+
public class YarnRoPE: Module, RoPEProvider {
7996
let dimensions: Int
8097
let traditional: Bool
8198
let maxPositionEmbeddings: Int
@@ -183,11 +200,15 @@ public class YarnRoPE: Module {
183200
freqs: self._freqs
184201
)
185202
}
203+
204+
public func apply(_ x: MLXArray, offset: Int) -> MLXArray {
205+
callAsFunction(x, offset: offset)
206+
}
186207
}
187208

188209
// MARK: - SuScaledRoPE (for longrope)
189210

190-
public class SuScaledRoPE: Module {
211+
public class SuScaledRoPE: Module, RoPEProvider {
191212
let dimensions: Int
192213
let base: Float
193214
let maxPositionEmbeddings: Int
@@ -259,6 +280,10 @@ public class SuScaledRoPE: Module {
259280
freqs: freqs
260281
)
261282
}
283+
284+
public func apply(_ x: MLXArray, offset: Int) -> MLXArray {
285+
callAsFunction(x, offset: offset)
286+
}
262287
}
263288

264289
// MARK: - RoPE Factory
@@ -270,7 +295,7 @@ public func initializeRope(
270295
traditional: Bool,
271296
scalingConfig: [String: StringOrNumber]?,
272297
maxPositionEmbeddings: Int?
273-
) -> Module {
298+
) -> any RoPEProvider {
274299
let ropeType: String = {
275300
if let config = scalingConfig,
276301
let typeValue = config["type"] ?? config["rope_type"],

0 commit comments

Comments
 (0)