Skip to content

Commit 0b2beed

Browse files
authored
Merge pull request #72 from rryam/cursor/embedder-validation-and-errors-9508
Embedder validation and errors
2 parents 2976473 + 654d998 commit 0b2beed

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

Sources/VecturaKit/Embedder/SwiftEmbedder.swift

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,14 @@ extension SwiftEmbedder: VecturaEmbedder {
151151

152152
// Ensure model is loaded
153153
try await ensureModelLoaded()
154-
let staticTruncateDimension = try staticEmbeddingsTruncateDimension()
155154

156155
let dim: Int
157156
if let model2vec = model2vecModel {
158157
// Note: 'dimienstion' is a typo in the upstream swift-embeddings library
159158
// See: swift-embeddings/Sources/Embeddings/Model2Vec/Model2VecModel.swift
160159
dim = model2vec.model.dimienstion
161160
} else if let staticEmbeddings = staticEmbeddingsModel {
161+
let staticTruncateDimension = try staticEmbeddingsTruncateDimension()
162162
dim = try Self.resolvedStaticEmbeddingDimension(
163163
baseDimension: staticEmbeddings.model.dimension,
164164
truncateDimension: staticTruncateDimension
@@ -220,12 +220,12 @@ extension SwiftEmbedder: VecturaEmbedder {
220220
/// - Throws: An error if embedding generation fails.
221221
public func embed(texts: [String]) async throws -> [[Float]] {
222222
try await ensureModelLoaded()
223-
let staticTruncateDimension = try staticEmbeddingsTruncateDimension()
224223

225224
let embeddingsTensor: MLTensor
226225
if let model2vec = model2vecModel {
227226
embeddingsTensor = try model2vec.batchEncode(texts)
228227
} else if let staticEmbeddings = staticEmbeddingsModel {
228+
let staticTruncateDimension = try staticEmbeddingsTruncateDimension()
229229
embeddingsTensor = try staticEmbeddings.batchEncode(
230230
texts,
231231
normalize: true,
@@ -264,12 +264,12 @@ extension SwiftEmbedder: VecturaEmbedder {
264264
/// - Throws: An error if embedding generation fails.
265265
public func embed(text: String) async throws -> [Float] {
266266
try await ensureModelLoaded()
267-
let staticTruncateDimension = try staticEmbeddingsTruncateDimension()
268267

269268
let embeddingTensor: MLTensor
270269
if let model2vec = model2vecModel {
271270
embeddingTensor = try model2vec.encode(text)
272271
} else if let staticEmbeddings = staticEmbeddingsModel {
272+
let staticTruncateDimension = try staticEmbeddingsTruncateDimension()
273273
embeddingTensor = try staticEmbeddings.encode(
274274
text,
275275
normalize: true,
@@ -334,15 +334,40 @@ extension Bert {
334334
case .id(let modelId, _):
335335
do {
336336
return try await loadModelBundle(from: modelId)
337+
} catch let cancellationError as CancellationError {
338+
throw cancellationError
337339
} catch {
340+
let originalError = error
341+
guard isKeyMappingError(originalError) else {
342+
throw originalError
343+
}
344+
338345
// Some BERT checkpoints (for example, google-bert/bert-base-uncased)
339346
// require alternative key mapping.
340-
return try await loadModelBundle(from: modelId, loadConfig: .googleBert)
347+
do {
348+
return try await loadModelBundle(from: modelId, loadConfig: .googleBert)
349+
} catch let cancellationError as CancellationError {
350+
throw cancellationError
351+
} catch {
352+
throw originalError
353+
}
341354
}
342355
case .folder(let url, _):
343356
return try await loadModelBundle(from: url)
344357
}
345358
}
359+
360+
/// Heuristically decides whether `error` looks like a weight key-mapping failure.
///
/// The check is purely textual: it lowercases both `String(describing:)` of the
/// error and its `NSError` localized description, joins them, and searches the
/// result for a fixed set of marker substrings.
///
/// - Parameter error: The error to classify.
/// - Returns: `true` if the combined, lowercased description contains any of
///   "key mapping", "key-mapping", "missing key", "unexpected key",
///   "state_dict", or "state dict"; otherwise `false`.
private static func isKeyMappingError(_ error: Error) -> Bool {
361+
// Lowercase both renderings so the substring checks below are case-insensitive.
let description = String(describing: error).lowercased()
362+
let localizedDescription = (error as NSError).localizedDescription.lowercased()
363+
// Search one combined string so a marker in either rendering is detected.
let combinedDescription = "\(description) \(localizedDescription)"
364+
return combinedDescription.contains("key mapping") ||
365+
combinedDescription.contains("key-mapping") ||
366+
combinedDescription.contains("missing key") ||
367+
combinedDescription.contains("unexpected key") ||
368+
combinedDescription.contains("state_dict") ||
369+
combinedDescription.contains("state dict")
370+
}
346371
}
347372

348373
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)

0 commit comments

Comments
 (0)