@@ -338,3 +338,65 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
338
338
}
339
339
}
340
340
}
341
+
342
+ #if TENSORFLOW_USE_STANDARD_TOOLCHAIN
343
+ @_spi ( Reflection) import Swift
344
+
345
+ func reflectionInit< T> ( type: T . Type , body: ( inout T , PartialKeyPath < T > ) -> Void ) -> T {
346
+ let x = UnsafeMutablePointer< T> . allocate( capacity: 1 )
347
+ defer { x. deallocate ( ) }
348
+ if !_forEachFieldWithKeyPath( of: type) { name, kp in
349
+ body ( & x. pointee, kp)
350
+ return true
351
+ } {
352
+ fatalError ( " Cannot initialize \( T . self) because of unknown fields. " )
353
+ }
354
+ return x. move ( )
355
+ }
356
+
357
+ extension TensorGroup {
358
+ public static var _typeList : [ TensorDataType ] {
359
+ var out = [ TensorDataType] ( )
360
+ if !( _forEachFieldWithKeyPath ( of: Self . self) { name, kp in
361
+ guard let valueType = type ( of: kp) . valueType as? TensorGroup . Type else { return false }
362
+ out += valueType. _typeList
363
+ return true
364
+ } ) {
365
+ fatalError ( " \( Self . self) does not have children that conform to TensorGroup. " )
366
+ }
367
+ return out
368
+ }
369
+ public static func initialize< Root> (
370
+ _ base: inout Root , _ kp: PartialKeyPath < Root > ,
371
+ _owning tensorHandles: UnsafePointer < CTensorHandle > ?
372
+ ) {
373
+ guard let kp = kp as? WritableKeyPath < Root , Self > else {
374
+ fatalError ( " \( kp) is not \( WritableKeyPath < Root , Self > . self) " )
375
+ }
376
+ withUnsafeMutablePointer ( to: & base[ keyPath: kp] ) { v in
377
+ v. initialize ( to: . init( _owning: tensorHandles) )
378
+ }
379
+ }
380
+ public init ( _owning tensorHandles: UnsafePointer < CTensorHandle > ? ) {
381
+ var i = 0
382
+ self = reflectionInit ( type: Self . self) { base, kp in
383
+ guard let valueType = type ( of: kp) . valueType as? TensorGroup . Type else {
384
+ fatalError ( " \( type ( of: kp) . valueType) does not conform to TensorGroup " )
385
+ }
386
+ valueType. initialize ( & base, kp, _owning: tensorHandles? . advanced ( by: i) )
387
+ i += Int ( valueType. _tensorHandleCount)
388
+ }
389
+ }
390
+ public func _unpackTensorHandles( into address: UnsafeMutablePointer < CTensorHandle > ? ) {
391
+ var i = 0
392
+ if !_forEachFieldWithKeyPath( of: Self . self) { name, kp in
393
+ guard let x = self [ keyPath: kp] as? TensorGroup else { return false }
394
+ x. _unpackTensorHandles ( into: address? . advanced ( by: i) )
395
+ i += Int ( type ( of: x) . _tensorHandleCount)
396
+ return true
397
+ } {
398
+ fatalError ( " Cannot unpack \( Self . self) because of non-TensorGroup fields. " )
399
+ }
400
+ }
401
+ }
402
+ #endif
0 commit comments