Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 47ebf6c

Browse files
author
marcrasi
authored
make it compile without hardcoded tangentvector conformances (#1135)
1 parent fb720c5 commit 47ebf6c

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

Tests/TensorFlowTests/OptimizerTests.swift

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,14 @@ class OptimizerTests: XCTestCase {
130130
}
131131

132132
/// A `Tensor<Float>` wrapper for testing optimizer numerical correctness.
133-
/// - Note: `KeyPathIterable` conformance is needed for `SGD`.
134-
struct NumericalValues: Differentiable & KeyPathIterable {
133+
/// - Note: `Layer` conformance is needed for `SGD`, because it makes the `TangentVector`
134+
/// conform to some required protocols.
135+
struct NumericalValues: Layer {
135136
var value = Tensor<Float>([0, 0, 0])
137+
138+
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
139+
input
140+
}
136141
}
137142

138143
/// Check expected weight and bias after updating `model` with `optimizer` `stepCount` times.

0 commit comments

Comments
 (0)