@@ -26,25 +26,82 @@ import org.tensorflow.ndarray.Shaped
2626 */
2727public val Operand <* >.shape: Shape
2828 get() = this .shape()
29+
30+
31+ public fun interface ShapeErrorLazyMessage {
32+ public fun message (actual : Shape , required : Shape ): String
33+ }
34+
35+ @PublishedApi
36+ internal val defaultShapeErrorMessage: ShapeErrorLazyMessage = ShapeErrorLazyMessage { actual, required ->
37+ " Shape $actual is not compatible with the required shape $required "
38+ }
2939
3040/* *
3141 * Require the [Shaped] object have a certain shape.
3242 *
33- * Throws [IllegalStateException] on failure.
43+ * @throws AssertionError if the shapes are not compatible
3444 */
35- public fun <T : Shaped > T.requireShape (shape : Shape ): T = apply {
36- check(this .shape().isCompatibleWith(shape)) {
37- " Shape ${this .shape()} is not compatible with the required shape $shape "
38- }
45+ public inline fun <T : Shaped > T.assertShape (
46+ requiredShape : Shape ,
47+ exception : ShapeErrorLazyMessage = defaultShapeErrorMessage
48+ ): T = apply {
49+ val actual = this .shape()
50+ assert (actual.isCompatibleWith(requiredShape)) { exception.message(actual, requiredShape) }
3951}
4052
4153/* *
4254 * Require the [Shaped] object have a certain shape.
4355 *
44- * Throws [IllegalStateException] on failure.
56+ * @throws AssertionError if the shapes are not compatible
4557 */
46- public fun <T : Shaped > T.requireShape (vararg shape : Long ): T = apply {
47- check(this .shape().isCompatibleWith(Shape .of(* shape))) {
48- " Shape ${this .shape()} is not compatible with the required shape $shape "
49- }
58+ public inline fun <T : Shaped > T.assertShape (
59+ vararg shape : Long ,
60+ exception : ShapeErrorLazyMessage = defaultShapeErrorMessage
61+ ): T = checkShape(Shape .of(* shape), exception)
62+
63+ /* *
64+ * Require the [Shaped] object have a certain shape.
65+ *
66+ * @throws IllegalArgumentException if the shapes are not compatible
67+ */
68+ public inline fun <T : Shaped > T.requireShape (
69+ requiredShape : Shape ,
70+ exception : ShapeErrorLazyMessage = defaultShapeErrorMessage
71+ ): T = apply {
72+ val actual = this .shape()
73+ require(actual.isCompatibleWith(requiredShape)) { exception.message(actual, requiredShape) }
5074}
75+
76+ /* *
77+ * Require the [Shaped] object have a certain shape.
78+ *
79+ * @throws IllegalArgumentException if the shapes are not compatible
80+ */
81+ public inline fun <T : Shaped > T.requireShape (
82+ vararg shape : Long ,
83+ exception : ShapeErrorLazyMessage = defaultShapeErrorMessage
84+ ): T = checkShape(Shape .of(* shape), exception)
85+
86+ /* *
87+ * Require the [Shaped] object have a certain shape.
88+ *
89+ * @throws IllegalStateException if the shapes are not compatible
90+ */
91+ public inline fun <T : Shaped > T.checkShape (
92+ requiredShape : Shape ,
93+ exception : ShapeErrorLazyMessage = defaultShapeErrorMessage
94+ ): T = apply {
95+ val actual = this .shape()
96+ check(actual.isCompatibleWith(requiredShape)) { exception.message(actual, requiredShape) }
97+ }
98+
99+ /* *
100+ * Require the [Shaped] object have a certain shape.
101+ *
102+ * @throws IllegalStateException if the shapes are not compatible
103+ */
104+ public inline fun <T : Shaped > T.checkShape (
105+ vararg shape : Long ,
106+ exception : ShapeErrorLazyMessage = defaultShapeErrorMessage
107+ ): T = checkShape(Shape .of(* shape), exception)
0 commit comments