@@ -151,14 +151,14 @@ extension SwiftEmbedder: VecturaEmbedder {
151151
152152 // Ensure model is loaded
153153 try await ensureModelLoaded ( )
154- let staticTruncateDimension = try staticEmbeddingsTruncateDimension ( )
155154
156155 let dim : Int
157156 if let model2vec = model2vecModel {
158157 // Note: 'dimienstion' is a typo in the upstream swift-embeddings library
159158 // See: swift-embeddings/Sources/Embeddings/Model2Vec/Model2VecModel.swift
160159 dim = model2vec. model. dimienstion
161160 } else if let staticEmbeddings = staticEmbeddingsModel {
161+ let staticTruncateDimension = try staticEmbeddingsTruncateDimension ( )
162162 dim = try Self . resolvedStaticEmbeddingDimension (
163163 baseDimension: staticEmbeddings. model. dimension,
164164 truncateDimension: staticTruncateDimension
@@ -220,12 +220,12 @@ extension SwiftEmbedder: VecturaEmbedder {
220220 /// - Throws: An error if embedding generation fails.
221221 public func embed( texts: [ String ] ) async throws -> [ [ Float ] ] {
222222 try await ensureModelLoaded ( )
223- let staticTruncateDimension = try staticEmbeddingsTruncateDimension ( )
224223
225224 let embeddingsTensor : MLTensor
226225 if let model2vec = model2vecModel {
227226 embeddingsTensor = try model2vec. batchEncode ( texts)
228227 } else if let staticEmbeddings = staticEmbeddingsModel {
228+ let staticTruncateDimension = try staticEmbeddingsTruncateDimension ( )
229229 embeddingsTensor = try staticEmbeddings. batchEncode (
230230 texts,
231231 normalize: true ,
@@ -264,12 +264,12 @@ extension SwiftEmbedder: VecturaEmbedder {
264264 /// - Throws: An error if embedding generation fails.
265265 public func embed( text: String ) async throws -> [ Float ] {
266266 try await ensureModelLoaded ( )
267- let staticTruncateDimension = try staticEmbeddingsTruncateDimension ( )
268267
269268 let embeddingTensor : MLTensor
270269 if let model2vec = model2vecModel {
271270 embeddingTensor = try model2vec. encode ( text)
272271 } else if let staticEmbeddings = staticEmbeddingsModel {
272+ let staticTruncateDimension = try staticEmbeddingsTruncateDimension ( )
273273 embeddingTensor = try staticEmbeddings. encode (
274274 text,
275275 normalize: true ,
@@ -334,15 +334,40 @@ extension Bert {
334334 case . id( let modelId, _) :
335335 do {
336336 return try await loadModelBundle ( from: modelId)
337+ } catch let cancellationError as CancellationError {
338+ throw cancellationError
337339 } catch {
340+ let originalError = error
341+ guard isKeyMappingError ( originalError) else {
342+ throw originalError
343+ }
344+
338345 // Some BERT checkpoints (for example, google-bert/bert-base-uncased)
339346 // require alternative key mapping.
340- return try await loadModelBundle ( from: modelId, loadConfig: . googleBert)
347+ do {
348+ return try await loadModelBundle ( from: modelId, loadConfig: . googleBert)
349+ } catch let cancellationError as CancellationError {
350+ throw cancellationError
351+ } catch {
352+ throw originalError
353+ }
341354 }
342355 case . folder( let url, _) :
343356 return try await loadModelBundle ( from: url)
344357 }
345358 }
359+
360+ private static func isKeyMappingError( _ error: Error ) -> Bool {
361+ let description = String ( describing: error) . lowercased ( )
362+ let localizedDescription = ( error as NSError ) . localizedDescription. lowercased ( )
363+ let combinedDescription = " \( description) \( localizedDescription) "
364+ return combinedDescription. contains ( " key mapping " ) ||
365+ combinedDescription. contains ( " key-mapping " ) ||
366+ combinedDescription. contains ( " missing key " ) ||
367+ combinedDescription. contains ( " unexpected key " ) ||
368+ combinedDescription. contains ( " state_dict " ) ||
369+ combinedDescription. contains ( " state dict " )
370+ }
346371}
347372
348373@available ( macOS 15 . 0 , iOS 18 . 0 , tvOS 18 . 0 , visionOS 2 . 0 , watchOS 11 . 0 , * )
0 commit comments