Skip to content

Commit 3f1a79b

Browse files
authored
Merge branch 'ml-explore:main' into main
2 parents 80cefa4 + bc3c20e commit 3f1a79b

File tree

5 files changed

+1719
-15
lines changed

5 files changed

+1719
-15
lines changed

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 246 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,30 @@ public struct GenerateParameters: Sendable {
7878
/// top p sampling
7979
public var topP: Float
8080

81+
/// top k sampling (0 disables)
82+
public var topK: Int
83+
84+
/// min p sampling threshold relative to the highest probability token (0 disables)
85+
public var minP: Float
86+
8187
/// penalty factor for repeating tokens
8288
public var repetitionPenalty: Float?
8389

8490
/// number of tokens to consider for repetition penalty
8591
public var repetitionContextSize: Int
8692

93+
/// additive penalty for tokens that appear in recent context
94+
public var presencePenalty: Float?
95+
96+
/// number of tokens to consider for presence penalty
97+
public var presenceContextSize: Int
98+
99+
/// additive penalty that scales with token frequency in recent context
100+
public var frequencyPenalty: Float?
101+
102+
/// number of tokens to consider for frequency penalty
103+
public var frequencyContextSize: Int
104+
87105
public init(
88106
maxTokens: Int? = nil,
89107
maxKVSize: Int? = nil,
@@ -92,8 +110,14 @@ public struct GenerateParameters: Sendable {
92110
quantizedKVStart: Int = 0,
93111
temperature: Float = 0.6,
94112
topP: Float = 1.0,
113+
topK: Int = 0,
114+
minP: Float = 0.0,
95115
repetitionPenalty: Float? = nil,
96116
repetitionContextSize: Int = 20,
117+
presencePenalty: Float? = nil,
118+
presenceContextSize: Int = 20,
119+
frequencyPenalty: Float? = nil,
120+
frequencyContextSize: Int = 20,
97121
prefillStepSize: Int = 512
98122
) {
99123
self.maxTokens = maxTokens
@@ -103,28 +127,71 @@ public struct GenerateParameters: Sendable {
103127
self.quantizedKVStart = quantizedKVStart
104128
self.temperature = temperature
105129
self.topP = topP
130+
self.topK = topK
131+
self.minP = minP
106132
self.repetitionPenalty = repetitionPenalty
107133
self.repetitionContextSize = repetitionContextSize
134+
self.presencePenalty = presencePenalty
135+
self.presenceContextSize = presenceContextSize
136+
self.frequencyPenalty = frequencyPenalty
137+
self.frequencyContextSize = frequencyContextSize
108138
self.prefillStepSize = prefillStepSize
109139
}
110140

111141
public func sampler() -> LogitSampler {
142+
let usesTopP = topP > 0 && topP < 1
143+
let usesTopK = topK > 0
144+
let usesMinP = minP > 0
145+
112146
if temperature == 0 {
113147
return ArgMaxSampler()
114-
} else if topP > 0 && topP < 1 {
115-
return TopPSampler(temperature: temperature, topP: topP)
148+
} else if usesTopP || usesTopK || usesMinP {
149+
return TopPSampler(temperature: temperature, topP: topP, topK: topK, minP: minP)
116150
} else {
117151
return CategoricalSampler(temperature: temperature)
118152
}
119153
}
120154

121155
public func processor() -> LogitProcessor? {
122-
if let repetitionPenalty, repetitionContextSize > 0 {
123-
return RepetitionContext(
124-
repetitionPenalty: repetitionPenalty, repetitionContextSize: repetitionContextSize)
156+
let repetitionContext: RepetitionContext?
157+
if let repetitionPenalty, repetitionPenalty != 0, repetitionContextSize > 0 {
158+
repetitionContext = RepetitionContext(
159+
repetitionPenalty: repetitionPenalty,
160+
repetitionContextSize: repetitionContextSize
161+
)
162+
} else {
163+
repetitionContext = nil
164+
}
165+
166+
let presenceContext: PresencePenaltyContext?
167+
if let presencePenalty, presencePenalty != 0, presenceContextSize > 0 {
168+
presenceContext = PresencePenaltyContext(
169+
presencePenalty: presencePenalty,
170+
presenceContextSize: presenceContextSize
171+
)
172+
} else {
173+
presenceContext = nil
174+
}
175+
176+
let frequencyContext: FrequencyPenaltyContext?
177+
if let frequencyPenalty, frequencyPenalty != 0, frequencyContextSize > 0 {
178+
frequencyContext = FrequencyPenaltyContext(
179+
frequencyPenalty: frequencyPenalty,
180+
frequencyContextSize: frequencyContextSize
181+
)
125182
} else {
183+
frequencyContext = nil
184+
}
185+
186+
if repetitionContext == nil && presenceContext == nil && frequencyContext == nil {
126187
return nil
127188
}
189+
190+
return PenaltyProcessor(
191+
repetitionContext: repetitionContext,
192+
presenceContext: presenceContext,
193+
frequencyContext: frequencyContext
194+
)
128195
}
129196
}
130197

@@ -137,15 +204,24 @@ public struct ArgMaxSampler: LogitSampler {
137204
}
138205
}
139206

140-
/// Sampler that uses `topP` and `temperature` to sample the logits.
207+
/// Sampler that uses probability filters (`topP`, `topK`, `minP`) and `temperature`
208+
/// to sample the logits.
141209
public struct TopPSampler: LogitSampler {
142210
let temp: MLXArray
143-
let topP: MLXArray
211+
let topP: MLXArray?
212+
let topK: Int?
213+
let minP: MLXArray?
144214
let randomState: MLXRandom.RandomState
145215

146-
public init(temperature: Float, topP: Float) {
216+
public init(temperature: Float, topP: Float = 1.0, topK: Int = 0, minP: Float = 0.0) {
147217
self.temp = MLXArray(temperature)
148-
self.topP = MLXArray(topP)
218+
if topP > 0 && topP < 1 {
219+
self.topP = MLXArray(topP)
220+
} else {
221+
self.topP = nil
222+
}
223+
self.topK = topK > 0 ? topK : nil
224+
self.minP = minP > 0 ? MLXArray(minP) : nil
149225
self.randomState = MLXRandom.RandomState()
150226
}
151227

@@ -156,18 +232,43 @@ public struct TopPSampler: LogitSampler {
156232
}
157233

158234
return withRandomState(randomState) {
159-
let probs = softmax(logits / temp, axis: -1)
235+
// Match mlx-lm Python behavior:
236+
// apply filtering on the base distribution, then apply temperature at sampling time.
237+
let probs = softmax(logits, axis: -1)
160238
let sortedIndices = argSort(probs, axis: -1)
161239

162240
// probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
163241
let sortedProbs = take(probs, sortedIndices, axis: -1).squeezed(axis: 0)
164242

165-
let cumulativeProbs = cumsum(sortedProbs, axis: -1)
243+
var filteredProbs = sortedProbs
244+
245+
if let topP {
246+
let cumulativeProbs = cumsum(sortedProbs, axis: -1)
247+
filteredProbs = MLX.where(
248+
cumulativeProbs .> (1 - topP), filteredProbs, zeros(like: filteredProbs))
249+
}
250+
251+
if let minP {
252+
let maxProbs = sortedProbs[0..., -1].expandedDimensions(axis: -1)
253+
let keepMask = sortedProbs .>= (maxProbs * minP)
254+
filteredProbs = MLX.where(keepMask, filteredProbs, zeros(like: filteredProbs))
255+
}
256+
257+
if let topK {
258+
let vocabularySize = sortedProbs.dim(-1)
259+
if topK < vocabularySize {
260+
let cutOff = vocabularySize - topK
261+
let sortedPositions = MLXArray(Array(0 ..< vocabularySize))
262+
let keepMask = sortedPositions .>= cutOff
263+
filteredProbs = MLX.where(
264+
keepMask, filteredProbs, zeros(like: filteredProbs))
265+
}
266+
}
166267

167-
let topProbs = MLX.where(
168-
cumulativeProbs .> (1 - topP), sortedProbs, zeros(like: sortedProbs))
268+
// Always keep the maximum-probability token so sampling always has a valid candidate.
269+
filteredProbs[0..., -1] = sortedProbs[0..., -1]
169270

170-
let sortedToken = categorical(log(topProbs))
271+
let sortedToken = categorical(log(filteredProbs) * (1 / temp))
171272
return sortedIndices.squeezed(axis: 0)[sortedToken]
172273
}
173274
}
@@ -244,6 +345,137 @@ public struct RepetitionContext: LogitProcessor {
244345
}
245346
}
246347

348+
/// Processor that applies an additive presence penalty to tokens in a recent context window.
349+
public struct PresencePenaltyContext: LogitProcessor {
350+
var tokens = [Int]()
351+
var index = 0
352+
353+
let presencePenalty: Float
354+
let presenceContextSize: Int
355+
356+
public init(presencePenalty: Float, presenceContextSize: Int) {
357+
precondition(presenceContextSize > 0)
358+
self.presencePenalty = presencePenalty
359+
self.presenceContextSize = presenceContextSize
360+
}
361+
362+
mutating public func prompt(_ prompt: MLXArray) {
363+
if prompt.shape[0] <= presenceContextSize {
364+
self.tokens = prompt.asArray(Int.self)
365+
} else {
366+
self.tokens = prompt[(-presenceContextSize)...].asArray(Int.self)
367+
}
368+
}
369+
370+
public func process(logits: MLXArray) -> MLXArray {
371+
if tokens.isEmpty {
372+
return logits
373+
}
374+
375+
let uniqueTokens = Array(Set(tokens))
376+
let indices = MLXArray(uniqueTokens.map { UInt32($0) })
377+
logits[0..., indices] = logits[0..., indices] - presencePenalty
378+
return logits
379+
}
380+
381+
mutating public func didSample(token: MLXArray) {
382+
if tokens.count >= presenceContextSize {
383+
tokens[index] = token.item(Int.self)
384+
index = (index + 1) % presenceContextSize
385+
} else {
386+
tokens.append(token.item(Int.self))
387+
}
388+
}
389+
}
390+
391+
/// Processor that applies an additive frequency penalty to tokens in a recent context window.
392+
public struct FrequencyPenaltyContext: LogitProcessor {
393+
var tokens = [Int]()
394+
var index = 0
395+
396+
let frequencyPenalty: Float
397+
let frequencyContextSize: Int
398+
399+
public init(frequencyPenalty: Float, frequencyContextSize: Int) {
400+
precondition(frequencyContextSize > 0)
401+
self.frequencyPenalty = frequencyPenalty
402+
self.frequencyContextSize = frequencyContextSize
403+
}
404+
405+
mutating public func prompt(_ prompt: MLXArray) {
406+
if prompt.shape[0] <= frequencyContextSize {
407+
self.tokens = prompt.asArray(Int.self)
408+
} else {
409+
self.tokens = prompt[(-frequencyContextSize)...].asArray(Int.self)
410+
}
411+
}
412+
413+
public func process(logits: MLXArray) -> MLXArray {
414+
if tokens.isEmpty {
415+
return logits
416+
}
417+
418+
var counts = [Int: Int]()
419+
for token in tokens {
420+
counts[token, default: 0] += 1
421+
}
422+
423+
let orderedTokens = Array(counts.keys)
424+
let indices = MLXArray(orderedTokens.map { UInt32($0) })
425+
let penalties = MLXArray(
426+
orderedTokens.map { frequencyPenalty * Float(counts[$0] ?? 0) }
427+
)
428+
logits[0..., indices] = logits[0..., indices] - penalties
429+
return logits
430+
}
431+
432+
mutating public func didSample(token: MLXArray) {
433+
if tokens.count >= frequencyContextSize {
434+
tokens[index] = token.item(Int.self)
435+
index = (index + 1) % frequencyContextSize
436+
} else {
437+
tokens.append(token.item(Int.self))
438+
}
439+
}
440+
}
441+
442+
/// Processor that composes penalty processors in Python mlx-lm order.
443+
public struct PenaltyProcessor: LogitProcessor {
444+
var repetitionContext: RepetitionContext?
445+
var presenceContext: PresencePenaltyContext?
446+
var frequencyContext: FrequencyPenaltyContext?
447+
448+
public init(
449+
repetitionContext: RepetitionContext?,
450+
presenceContext: PresencePenaltyContext?,
451+
frequencyContext: FrequencyPenaltyContext?
452+
) {
453+
self.repetitionContext = repetitionContext
454+
self.presenceContext = presenceContext
455+
self.frequencyContext = frequencyContext
456+
}
457+
458+
mutating public func prompt(_ prompt: MLXArray) {
459+
repetitionContext?.prompt(prompt)
460+
presenceContext?.prompt(prompt)
461+
frequencyContext?.prompt(prompt)
462+
}
463+
464+
public func process(logits: MLXArray) -> MLXArray {
465+
var logits = logits
466+
logits = repetitionContext?.process(logits: logits) ?? logits
467+
logits = presenceContext?.process(logits: logits) ?? logits
468+
logits = frequencyContext?.process(logits: logits) ?? logits
469+
return logits
470+
}
471+
472+
mutating public func didSample(token: MLXArray) {
473+
repetitionContext?.didSample(token: token)
474+
presenceContext?.didSample(token: token)
475+
frequencyContext?.didSample(token: token)
476+
}
477+
}
478+
247479
/// Generator of tokens.
248480
///
249481
/// This is typically used via a call to ``generate(input:cache:parameters:context:)`` returning `AsyncStream<Generation>`.

0 commit comments

Comments
 (0)