Skip to content

Commit 76ae65c

Browse files
Fix threshold values to avoid race condition (#220)
fix: pass threshold values Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
1 parent f8bb54e commit 76ae65c

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

Sources/YOLO/YOLO.swift

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ import UIKit
1919
public class YOLO: @unchecked Sendable {
2020
var predictor: Predictor!
2121

22+
private var pendingNumItems: Int?
23+
private var pendingConfidence: Double?
24+
private var pendingIou: Double?
25+
2226
/// Initialize YOLO with remote URL for automatic download and caching
2327
public init(url: URL, task: YOLOTask, completion: @escaping (Result<YOLO, Error>) -> Void) {
2428
let downloader = YOLOModelDownloader()
@@ -73,6 +77,12 @@ public class YOLO: @unchecked Sendable {
7377
switch result {
7478
case .success(let predictor):
7579
self.predictor = predictor
80+
self.pendingNumItems.map { predictor.setNumItemsThreshold(numItems: $0) }
81+
self.pendingConfidence.map { predictor.setConfidenceThreshold(confidence: $0) }
82+
self.pendingIou.map { predictor.setIouThreshold(iou: $0) }
83+
self.pendingNumItems = nil
84+
self.pendingConfidence = nil
85+
self.pendingIou = nil
7686
completion?(.success(self))
7787
case .failure(let error):
7888
print("Failed to load model with error: \(error)")
@@ -100,13 +110,14 @@ public class YOLO: @unchecked Sendable {
100110
/// Sets the maximum number of detection items to include in results.
101111
/// - Parameter numItems: The maximum number of items to include (default is 30).
102112
public func setNumItemsThreshold(_ numItems: Int) {
113+
pendingNumItems = numItems
103114
(predictor as? BasePredictor)?.setNumItemsThreshold(numItems: numItems)
104115
}
105116

106117
/// Gets the current maximum number of detection items.
107118
/// - Returns: The current threshold value, or nil if not applicable.
108119
public func getNumItemsThreshold() -> Int? {
109-
(predictor as? BasePredictor)?.numItemsThreshold
120+
(predictor as? BasePredictor)?.numItemsThreshold ?? pendingNumItems
110121
}
111122

112123
/// Sets the confidence threshold for filtering results.
@@ -116,13 +127,14 @@ public class YOLO: @unchecked Sendable {
116127
print("Warning: Confidence threshold should be between 0.0 and 1.0")
117128
return
118129
}
130+
pendingConfidence = confidence
119131
(predictor as? BasePredictor)?.setConfidenceThreshold(confidence: confidence)
120132
}
121133

122134
/// Gets the current confidence threshold.
123-
/// - Returns: The current confidence threshold value, or nil if not applicable.
135+
/// - Returns: The current threshold value, or nil if not applicable.
124136
public func getConfidenceThreshold() -> Double? {
125-
(predictor as? BasePredictor)?.confidenceThreshold
137+
(predictor as? BasePredictor)?.confidenceThreshold ?? pendingConfidence
126138
}
127139

128140
/// Sets the IoU (Intersection over Union) threshold for non-maximum suppression.
@@ -132,13 +144,14 @@ public class YOLO: @unchecked Sendable {
132144
print("Warning: IoU threshold should be between 0.0 and 1.0")
133145
return
134146
}
147+
pendingIou = iou
135148
(predictor as? BasePredictor)?.setIouThreshold(iou: iou)
136149
}
137150

138151
/// Gets the current IoU threshold.
139-
/// - Returns: The current IoU threshold value, or nil if not applicable.
152+
/// - Returns: The current threshold value, or nil if not applicable.
140153
public func getIouThreshold() -> Double? {
141-
(predictor as? BasePredictor)?.iouThreshold
154+
(predictor as? BasePredictor)?.iouThreshold ?? pendingIou
142155
}
143156

144157
/// Sets all thresholds at once.

0 commit comments

Comments
 (0)