Skip to content

Commit eee0f9f

Browse files
committed
Better shape assertions
Signed-off-by: Ryan Nett <[email protected]>
1 parent ebfdd8e commit eee0f9f

File tree

1 file changed

+67
-10
lines changed
  • tensorflow-kotlin-parent/tensorflow-core-kotlin/src/main/kotlin/org/tensorflow

1 file changed

+67
-10
lines changed

tensorflow-kotlin-parent/tensorflow-core-kotlin/src/main/kotlin/org/tensorflow/OperandHelpers.kt

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,82 @@ import org.tensorflow.ndarray.Shaped
2626
*/
2727
public val Operand<*>.shape: Shape
2828
get() = this.shape()
29+
30+
31+
public fun interface ShapeErrorLazyMessage{
32+
public fun message(actual: Shape, required: Shape): String
33+
}
34+
35+
@PublishedApi
36+
internal val defaultShapeErrorMessage: ShapeErrorLazyMessage = ShapeErrorLazyMessage { actual, required ->
37+
"Shape $actual is not compatible with the required shape $required"
38+
}
2939

3040
/**
3141
* Require the [Shaped] object have a certain shape.
3242
*
33-
* Throws [IllegalStateException] on failure.
43+
* @throws AssertionError if the shapes are not compatible
3444
*/
35-
public fun <T : Shaped> T.requireShape(shape: Shape): T = apply {
36-
check(this.shape().isCompatibleWith(shape)) {
37-
"Shape ${this.shape()} is not compatible with the required shape $shape"
38-
}
45+
public inline fun <T : Shaped> T.assertShape(
46+
requiredShape: Shape,
47+
exception: ShapeErrorLazyMessage = defaultShapeErrorMessage
48+
): T = apply {
49+
val actual = this.shape()
50+
assert(actual.isCompatibleWith(requiredShape)) { exception.message(actual, requiredShape) }
3951
}
4052

4153
/**
4254
* Require the [Shaped] object have a certain shape.
4355
*
44-
* Throws [IllegalStateException] on failure.
56+
* @throws AssertionError if the shapes are not compatible
4557
*/
46-
public fun <T : Shaped> T.requireShape(vararg shape: Long): T = apply {
47-
check(this.shape().isCompatibleWith(Shape.of(*shape))) {
48-
"Shape ${this.shape()} is not compatible with the required shape $shape"
49-
}
58+
public inline fun <T : Shaped> T.assertShape(
59+
vararg shape: Long,
60+
exception: ShapeErrorLazyMessage = defaultShapeErrorMessage
61+
): T = checkShape(Shape.of(*shape), exception)
62+
63+
/**
64+
* Require the [Shaped] object have a certain shape.
65+
*
66+
* @throws IllegalArgumentException if the shapes are not compatible
67+
*/
68+
public inline fun <T : Shaped> T.requireShape(
69+
requiredShape: Shape,
70+
exception: ShapeErrorLazyMessage = defaultShapeErrorMessage
71+
): T = apply {
72+
val actual = this.shape()
73+
require(actual.isCompatibleWith(requiredShape)) { exception.message(actual, requiredShape) }
5074
}
75+
76+
/**
77+
* Require the [Shaped] object have a certain shape.
78+
*
79+
* @throws IllegalArgumentException if the shapes are not compatible
80+
*/
81+
public inline fun <T : Shaped> T.requireShape(
82+
vararg shape: Long,
83+
exception: ShapeErrorLazyMessage = defaultShapeErrorMessage
84+
): T = checkShape(Shape.of(*shape), exception)
85+
86+
/**
87+
* Require the [Shaped] object have a certain shape.
88+
*
89+
* @throws IllegalStateException if the shapes are not compatible
90+
*/
91+
public inline fun <T : Shaped> T.checkShape(
92+
requiredShape: Shape,
93+
exception: ShapeErrorLazyMessage = defaultShapeErrorMessage
94+
): T = apply {
95+
val actual = this.shape()
96+
check(actual.isCompatibleWith(requiredShape)) { exception.message(actual, requiredShape) }
97+
}
98+
99+
/**
100+
* Require the [Shaped] object have a certain shape.
101+
*
102+
* @throws IllegalStateException if the shapes are not compatible
103+
*/
104+
public inline fun <T : Shaped> T.checkShape(
105+
vararg shape: Long,
106+
exception: ShapeErrorLazyMessage = defaultShapeErrorMessage
107+
): T = checkShape(Shape.of(*shape), exception)

0 commit comments

Comments
 (0)