Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5d57293

Browse files
authored
Implement TensorGroup based on reflection. (#1143)
1 parent 86f09eb commit 5d57293

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

Sources/TensorFlow/Core/TensorGroup.swift

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,65 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
338338
}
339339
}
340340
}
341+
342+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
343+
@_spi(Reflection) import Swift
344+
345+
func reflectionInit<T>(type: T.Type, body: (inout T, PartialKeyPath<T>) -> Void) -> T {
346+
let x = UnsafeMutablePointer<T>.allocate(capacity: 1)
347+
defer { x.deallocate() }
348+
if !_forEachFieldWithKeyPath(of: type) { name, kp in
349+
body(&x.pointee, kp)
350+
return true
351+
} {
352+
fatalError("Cannot initialize \(T.self) because of unknown fields.")
353+
}
354+
return x.move()
355+
}
356+
357+
extension TensorGroup {
358+
public static var _typeList: [TensorDataType] {
359+
var out = [TensorDataType]()
360+
if !(_forEachFieldWithKeyPath(of: Self.self) { name, kp in
361+
guard let valueType = type(of: kp).valueType as? TensorGroup.Type else { return false }
362+
out += valueType._typeList
363+
return true
364+
}) {
365+
fatalError("\(Self.self) does not have children that conform to TensorGroup.")
366+
}
367+
return out
368+
}
369+
public static func initialize<Root>(
370+
_ base: inout Root, _ kp: PartialKeyPath<Root>,
371+
_owning tensorHandles: UnsafePointer<CTensorHandle>?
372+
) {
373+
guard let kp = kp as? WritableKeyPath<Root, Self> else {
374+
fatalError("\(kp) is not \(WritableKeyPath<Root, Self>.self)")
375+
}
376+
withUnsafeMutablePointer(to: &base[keyPath: kp]) { v in
377+
v.initialize(to: .init(_owning: tensorHandles))
378+
}
379+
}
380+
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
381+
var i = 0
382+
self = reflectionInit(type: Self.self) { base, kp in
383+
guard let valueType = type(of: kp).valueType as? TensorGroup.Type else {
384+
fatalError("\(type(of: kp).valueType) does not conform to TensorGroup")
385+
}
386+
valueType.initialize(&base, kp, _owning: tensorHandles?.advanced(by: i))
387+
i += Int(valueType._tensorHandleCount)
388+
}
389+
}
390+
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
391+
var i = 0
392+
if !_forEachFieldWithKeyPath(of: Self.self) { name, kp in
393+
guard let x = self[keyPath: kp] as? TensorGroup else { return false }
394+
x._unpackTensorHandles(into: address?.advanced(by: i))
395+
i += Int(type(of: x)._tensorHandleCount)
396+
return true
397+
} {
398+
fatalError("Cannot unpack \(Self.self) because of non-TensorGroup fields.")
399+
}
400+
}
401+
}
402+
#endif

Tests/TensorFlowTests/LazyTensorEvaluationTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ final class LazyTensorEvaluationTests: LazyTensorTestCase {
7777
}
7878

7979
struct SimpleOutput: TensorGroup {
80-
let a: TensorHandle<Int32>
81-
let b: TensorHandle<Int32>
80+
var a: TensorHandle<Int32>
81+
var b: TensorHandle<Int32>
8282
}
8383

8484
func testNoOutputOperations() {

Tests/TensorFlowTests/TensorGroupTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ struct Mixed: TensorGroup, Equatable {
2727
// Mutable.
2828
var float: Tensor<Float>
2929
// Immutable.
30-
let int: Tensor<Int32>
30+
var int: Tensor<Int32>
3131
}
3232

3333
struct Nested: TensorGroup, Equatable {
3434
// Immutable.
35-
let simple: Simple
35+
var simple: Simple
3636
// Mutable.
3737
var mixed: Mixed
3838
}

0 commit comments

Comments
 (0)