@@ -78,12 +78,30 @@ public struct GenerateParameters: Sendable {
7878 /// top p sampling
7979 public var topP : Float
8080
81+ /// top k sampling (0 disables)
82+ public var topK : Int
83+
84+ /// min p sampling threshold relative to the highest probability token (0 disables)
85+ public var minP : Float
86+
8187 /// penalty factor for repeating tokens
8288 public var repetitionPenalty : Float ?
8389
8490 /// number of tokens to consider for repetition penalty
8591 public var repetitionContextSize : Int
8692
93+ /// additive penalty for tokens that appear in recent context
94+ public var presencePenalty : Float ?
95+
96+ /// number of tokens to consider for presence penalty
97+ public var presenceContextSize : Int
98+
99+ /// additive penalty that scales with token frequency in recent context
100+ public var frequencyPenalty : Float ?
101+
102+ /// number of tokens to consider for frequency penalty
103+ public var frequencyContextSize : Int
104+
87105 public init (
88106 maxTokens: Int ? = nil ,
89107 maxKVSize: Int ? = nil ,
@@ -92,8 +110,14 @@ public struct GenerateParameters: Sendable {
92110 quantizedKVStart: Int = 0 ,
93111 temperature: Float = 0.6 ,
94112 topP: Float = 1.0 ,
113+ topK: Int = 0 ,
114+ minP: Float = 0.0 ,
95115 repetitionPenalty: Float ? = nil ,
96116 repetitionContextSize: Int = 20 ,
117+ presencePenalty: Float ? = nil ,
118+ presenceContextSize: Int = 20 ,
119+ frequencyPenalty: Float ? = nil ,
120+ frequencyContextSize: Int = 20 ,
97121 prefillStepSize: Int = 512
98122 ) {
99123 self . maxTokens = maxTokens
@@ -103,28 +127,71 @@ public struct GenerateParameters: Sendable {
103127 self . quantizedKVStart = quantizedKVStart
104128 self . temperature = temperature
105129 self . topP = topP
130+ self . topK = topK
131+ self . minP = minP
106132 self . repetitionPenalty = repetitionPenalty
107133 self . repetitionContextSize = repetitionContextSize
134+ self . presencePenalty = presencePenalty
135+ self . presenceContextSize = presenceContextSize
136+ self . frequencyPenalty = frequencyPenalty
137+ self . frequencyContextSize = frequencyContextSize
108138 self . prefillStepSize = prefillStepSize
109139 }
110140
/// Builds the sampler implied by these parameters.
///
/// - Returns: an `ArgMaxSampler` when `temperature` is exactly zero; a `TopPSampler`
///   when at least one probability filter (`topP`, `topK`, `minP`) is enabled;
///   otherwise a `CategoricalSampler`.
public func sampler() -> LogitSampler {
    // Zero temperature short-circuits everything else: no filter is consulted.
    guard temperature != 0 else {
        return ArgMaxSampler()
    }

    // A filter counts as enabled only when its parameter leaves its "disabled" value:
    // topP strictly inside (0, 1), topK above zero, minP above zero.
    let nucleusEnabled = topP > 0 && topP < 1
    let anyFilterEnabled = nucleusEnabled || topK > 0 || minP > 0

    if anyFilterEnabled {
        return TopPSampler(temperature: temperature, topP: topP, topK: topK, minP: minP)
    }
    return CategoricalSampler(temperature: temperature)
}
120154
/// Builds the logit processor implied by the configured penalties.
///
/// A penalty is enabled when its value is set, is non-zero, and its context size is
/// positive.
///
/// - Returns: a `PenaltyProcessor` wrapping whichever of the repetition, presence and
///   frequency contexts are enabled, or `nil` when none of them is.
public func processor() -> LogitProcessor? {
    var repetition: RepetitionContext?
    if let repetitionPenalty, repetitionPenalty != 0, repetitionContextSize > 0 {
        repetition = RepetitionContext(
            repetitionPenalty: repetitionPenalty,
            repetitionContextSize: repetitionContextSize
        )
    }

    var presence: PresencePenaltyContext?
    if let presencePenalty, presencePenalty != 0, presenceContextSize > 0 {
        presence = PresencePenaltyContext(
            presencePenalty: presencePenalty,
            presenceContextSize: presenceContextSize
        )
    }

    var frequency: FrequencyPenaltyContext?
    if let frequencyPenalty, frequencyPenalty != 0, frequencyContextSize > 0 {
        frequency = FrequencyPenaltyContext(
            frequencyPenalty: frequencyPenalty,
            frequencyContextSize: frequencyContextSize
        )
    }

    // Nothing enabled: report "no processor" rather than an empty wrapper.
    guard repetition != nil || presence != nil || frequency != nil else {
        return nil
    }

    return PenaltyProcessor(
        repetitionContext: repetition,
        presenceContext: presence,
        frequencyContext: frequency
    )
}
129196}
130197
@@ -137,15 +204,24 @@ public struct ArgMaxSampler: LogitSampler {
137204 }
138205}
139206
140- /// Sampler that uses `topP` and `temperature` to sample the logits.
207+ /// Sampler that uses probability filters (`topP`, `topK`, `minP`) and `temperature`
208+ /// to sample the logits.
141209public struct TopPSampler : LogitSampler {
142210 let temp : MLXArray
143- let topP : MLXArray
211+ let topP : MLXArray ?
212+ let topK : Int ?
213+ let minP : MLXArray ?
144214 let randomState : MLXRandom . RandomState
145215
146- public init ( temperature: Float , topP: Float ) {
216+ public init ( temperature: Float , topP: Float = 1.0 , topK : Int = 0 , minP : Float = 0.0 ) {
147217 self . temp = MLXArray ( temperature)
148- self . topP = MLXArray ( topP)
218+ if topP > 0 && topP < 1 {
219+ self . topP = MLXArray ( topP)
220+ } else {
221+ self . topP = nil
222+ }
223+ self . topK = topK > 0 ? topK : nil
224+ self . minP = minP > 0 ? MLXArray ( minP) : nil
149225 self . randomState = MLXRandom . RandomState ( )
150226 }
151227
@@ -156,18 +232,43 @@ public struct TopPSampler: LogitSampler {
156232 }
157233
158234 return withRandomState ( randomState) {
159- let probs = softmax ( logits / temp, axis: - 1 )
235+ // Match mlx-lm Python behavior:
236+ // apply filtering on the base distribution, then apply temperature at sampling time.
237+ let probs = softmax ( logits, axis: - 1 )
160238 let sortedIndices = argSort ( probs, axis: - 1 )
161239
162240 // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
163241 let sortedProbs = take ( probs, sortedIndices, axis: - 1 ) . squeezed ( axis: 0 )
164242
165- let cumulativeProbs = cumsum ( sortedProbs, axis: - 1 )
243+ var filteredProbs = sortedProbs
244+
245+ if let topP {
246+ let cumulativeProbs = cumsum ( sortedProbs, axis: - 1 )
247+ filteredProbs = MLX . where (
248+ cumulativeProbs .> ( 1 - topP) , filteredProbs, zeros ( like: filteredProbs) )
249+ }
250+
251+ if let minP {
252+ let maxProbs = sortedProbs [ 0 ... , - 1 ] . expandedDimensions ( axis: - 1 )
253+ let keepMask = sortedProbs .>= ( maxProbs * minP)
254+ filteredProbs = MLX . where ( keepMask, filteredProbs, zeros ( like: filteredProbs) )
255+ }
256+
257+ if let topK {
258+ let vocabularySize = sortedProbs. dim ( - 1 )
259+ if topK < vocabularySize {
260+ let cutOff = vocabularySize - topK
261+ let sortedPositions = MLXArray ( Array ( 0 ..< vocabularySize) )
262+ let keepMask = sortedPositions .>= cutOff
263+ filteredProbs = MLX . where (
264+ keepMask, filteredProbs, zeros ( like: filteredProbs) )
265+ }
266+ }
166267
167- let topProbs = MLX . where (
168- cumulativeProbs .> ( 1 - topP ) , sortedProbs, zeros ( like : sortedProbs ) )
268+ // Always keep the maximum-probability token so sampling always has a valid candidate.
269+ filteredProbs [ 0 ... , - 1 ] = sortedProbs [ 0 ... , - 1 ]
169270
170- let sortedToken = categorical ( log ( topProbs ) )
271+ let sortedToken = categorical ( log ( filteredProbs ) * ( 1 / temp ) )
171272 return sortedIndices. squeezed ( axis: 0 ) [ sortedToken]
172273 }
173274 }
@@ -244,6 +345,137 @@ public struct RepetitionContext: LogitProcessor {
244345 }
245346}
246347
/// Processor that applies an additive presence penalty to tokens in a recent context window.
///
/// Each distinct token id currently in the window has `presencePenalty` subtracted from
/// its logit once, no matter how many times it occurs in the window.
public struct PresencePenaltyContext: LogitProcessor {
    /// Ring buffer holding at most `presenceContextSize` recent token ids.
    var tokens = [Int]()
    /// Slot in `tokens` that the next sampled token overwrites once the buffer is full.
    var index = 0

    let presencePenalty: Float
    let presenceContextSize: Int

    /// - Parameters:
    ///   - presencePenalty: value subtracted from the logit of every token in the window.
    ///   - presenceContextSize: maximum number of recent tokens tracked; must be positive.
    public init(presencePenalty: Float, presenceContextSize: Int) {
        precondition(presenceContextSize > 0)
        self.presencePenalty = presencePenalty
        self.presenceContextSize = presenceContextSize
    }

    /// Seeds the window with the last `presenceContextSize` tokens of `prompt`
    /// (or all of them when the prompt is shorter).
    mutating public func prompt(_ prompt: MLXArray) {
        if prompt.shape[0] <= presenceContextSize {
            tokens = prompt.asArray(Int.self)
        } else {
            tokens = prompt[(-presenceContextSize)...].asArray(Int.self)
        }
        // The freshly loaded tokens are in chronological order, so the oldest one sits in
        // slot 0. Without this reset a cursor left over from earlier sampling would make
        // the next overwrite evict a token that is not the oldest.
        index = 0
    }

    /// Subtracts `presencePenalty` from the logits of every distinct token in the window.
    ///
    /// The adjustment is written through the subscript of the `logits` argument, and that
    /// same array is returned. When the window is empty the logits are returned untouched.
    public func process(logits: MLXArray) -> MLXArray {
        guard !tokens.isEmpty else {
            return logits
        }

        // Deduplicate so each token id is penalized exactly once.
        let distinctIds = Set(tokens).map { UInt32($0) }
        let indices = MLXArray(distinctIds)
        logits[0..., indices] = logits[0..., indices] - presencePenalty
        return logits
    }

    /// Records a sampled token, evicting the oldest entry once the window is full.
    mutating public func didSample(token: MLXArray) {
        let sampled = token.item(Int.self)
        if tokens.count >= presenceContextSize {
            tokens[index] = sampled
            index = (index + 1) % presenceContextSize
        } else {
            tokens.append(sampled)
        }
    }
}
390+
/// Processor that applies an additive frequency penalty to tokens in a recent context window.
///
/// Each token id in the window has `frequencyPenalty * occurrences` subtracted from its
/// logit, where `occurrences` is how many times the id appears in the window.
public struct FrequencyPenaltyContext: LogitProcessor {
    /// Ring buffer holding at most `frequencyContextSize` recent token ids.
    var tokens = [Int]()
    /// Slot in `tokens` that the next sampled token overwrites once the buffer is full.
    var index = 0

    let frequencyPenalty: Float
    let frequencyContextSize: Int

    /// - Parameters:
    ///   - frequencyPenalty: per-occurrence value subtracted from a token's logit.
    ///   - frequencyContextSize: maximum number of recent tokens tracked; must be positive.
    public init(frequencyPenalty: Float, frequencyContextSize: Int) {
        precondition(frequencyContextSize > 0)
        self.frequencyPenalty = frequencyPenalty
        self.frequencyContextSize = frequencyContextSize
    }

    /// Seeds the window with the last `frequencyContextSize` tokens of `prompt`
    /// (or all of them when the prompt is shorter).
    mutating public func prompt(_ prompt: MLXArray) {
        if prompt.shape[0] <= frequencyContextSize {
            tokens = prompt.asArray(Int.self)
        } else {
            tokens = prompt[(-frequencyContextSize)...].asArray(Int.self)
        }
        // The freshly loaded tokens are in chronological order, so the oldest one sits in
        // slot 0. Without this reset a cursor left over from earlier sampling would make
        // the next overwrite evict a token that is not the oldest.
        index = 0
    }

    /// Subtracts `frequencyPenalty * count` from the logit of every token in the window.
    ///
    /// The adjustment is written through the subscript of the `logits` argument, and that
    /// same array is returned. When the window is empty the logits are returned untouched.
    public func process(logits: MLXArray) -> MLXArray {
        guard !tokens.isEmpty else {
            return logits
        }

        // Histogram of token id -> occurrences inside the window.
        let counts = tokens.reduce(into: [Int: Int]()) { histogram, token in
            histogram[token, default: 0] += 1
        }

        // Snapshot the pairs once so the index list and the penalty list are guaranteed
        // to line up element by element.
        let entries = Array(counts)
        let indices = MLXArray(entries.map { UInt32($0.key) })
        let penalties = MLXArray(entries.map { frequencyPenalty * Float($0.value) })

        logits[0..., indices] = logits[0..., indices] - penalties
        return logits
    }

    /// Records a sampled token, evicting the oldest entry once the window is full.
    mutating public func didSample(token: MLXArray) {
        let sampled = token.item(Int.self)
        if tokens.count >= frequencyContextSize {
            tokens[index] = sampled
            index = (index + 1) % frequencyContextSize
        } else {
            tokens.append(sampled)
        }
    }
}
441+
/// Processor that chains the optional penalty processors in a fixed order:
/// repetition, then presence, then frequency (described by the original author as the
/// Python mlx-lm order). Any stage that is `nil` is skipped.
public struct PenaltyProcessor: LogitProcessor {
    var repetitionContext: RepetitionContext?
    var presenceContext: PresencePenaltyContext?
    var frequencyContext: FrequencyPenaltyContext?

    /// Wraps the given stages; pass `nil` for any penalty that is disabled.
    public init(
        repetitionContext: RepetitionContext?,
        presenceContext: PresencePenaltyContext?,
        frequencyContext: FrequencyPenaltyContext?
    ) {
        self.repetitionContext = repetitionContext
        self.presenceContext = presenceContext
        self.frequencyContext = frequencyContext
    }

    /// Forwards the prompt tokens to every configured stage.
    mutating public func prompt(_ prompt: MLXArray) {
        repetitionContext?.prompt(prompt)
        presenceContext?.prompt(prompt)
        frequencyContext?.prompt(prompt)
    }

    /// Runs the logits through each configured stage in order and returns the result.
    public func process(logits: MLXArray) -> MLXArray {
        var adjusted = logits
        if let repetitionContext {
            adjusted = repetitionContext.process(logits: adjusted)
        }
        if let presenceContext {
            adjusted = presenceContext.process(logits: adjusted)
        }
        if let frequencyContext {
            adjusted = frequencyContext.process(logits: adjusted)
        }
        return adjusted
    }

    /// Forwards the sampled token to every configured stage so each can update its window.
    mutating public func didSample(token: MLXArray) {
        repetitionContext?.didSample(token: token)
        presenceContext?.didSample(token: token)
        frequencyContext?.didSample(token: token)
    }
}
478+
247479/// Generator of tokens.
248480///
249481/// This is typically used via a call to ``generate(input:cache:parameters:context:)`` returning `AsyncStream<Generation>`.
0 commit comments