@@ -82,32 +82,34 @@ final class LazyTensorEvaluationTests: LazyTensorTestCase {
82
82
}
83
83
84
84
func testNoOutputOperations( ) {
85
- let elements1 : Tensor < Int32 > = [ 0 , 1 , 2 ]
86
- let elements2 : Tensor < Int32 > = [ 10 , 11 , 12 ]
87
- let outputTypes = [ Int32 . tensorFlowDataType, Int32 . tensorFlowDataType]
88
- let outputShapes : [ TensorShape ? ] = [ nil , nil ]
89
- let dataset : VariantHandle = _Raw. tensorSliceDataset (
90
- components: [ elements1, elements2] ,
91
- outputShapes: outputShapes
92
- )
93
- let iterator : ResourceHandle = _Raw. iteratorV2 (
94
- sharedName: " blah " ,
95
- container: " earth " , outputTypes: outputTypes, outputShapes: outputShapes
96
- )
97
- // `dataset` and `iterator` should not be materialized yet.
98
- XCTAssertFalse ( isMaterialized ( dataset. handle) )
99
- XCTAssertFalse ( isMaterialized ( iterator. handle) )
100
- _Raw. makeIterator ( dataset: dataset, iterator: iterator)
101
-
102
- // `dataset` and `iterator` should be materialized now as
103
- // makeIterator executes.
104
- XCTAssertTrue ( isMaterialized ( dataset. handle) )
105
- XCTAssertTrue ( isMaterialized ( iterator. handle) )
106
- let next : SimpleOutput = _Raw. iteratorGetNext (
107
- iterator: iterator, outputShapes: outputShapes
108
- )
109
- XCTAssertEqual ( Tensor ( handle: next. a) . scalarized ( ) , 0 )
110
- XCTAssertEqual ( Tensor ( handle: next. b) . scalarized ( ) , 10 )
85
+ withDevice ( . cpu) {
86
+ let elements1 : Tensor < Int32 > = [ 0 , 1 , 2 ]
87
+ let elements2 : Tensor < Int32 > = [ 10 , 11 , 12 ]
88
+ let outputTypes = [ Int32 . tensorFlowDataType, Int32 . tensorFlowDataType]
89
+ let outputShapes : [ TensorShape ? ] = [ nil , nil ]
90
+ let dataset : VariantHandle = _Raw. tensorSliceDataset (
91
+ components: [ elements1, elements2] ,
92
+ outputShapes: outputShapes
93
+ )
94
+ let iterator : ResourceHandle = _Raw. iteratorV2 (
95
+ sharedName: " blah " ,
96
+ container: " earth " , outputTypes: outputTypes, outputShapes: outputShapes
97
+ )
98
+ // `dataset` and `iterator` should not be materialized yet.
99
+ XCTAssertFalse ( isMaterialized ( dataset. handle) )
100
+ XCTAssertFalse ( isMaterialized ( iterator. handle) )
101
+ _Raw. makeIterator ( dataset: dataset, iterator: iterator)
102
+
103
+ // `dataset` and `iterator` should be materialized now as
104
+ // makeIterator executes.
105
+ XCTAssertTrue ( isMaterialized ( dataset. handle) )
106
+ XCTAssertTrue ( isMaterialized ( iterator. handle) )
107
+ let next : SimpleOutput = _Raw. iteratorGetNext (
108
+ iterator: iterator, outputShapes: outputShapes
109
+ )
110
+ XCTAssertEqual ( Tensor ( handle: next. a) . scalarized ( ) , 0 )
111
+ XCTAssertEqual ( Tensor ( handle: next. b) . scalarized ( ) , 10 )
112
+ }
111
113
}
112
114
113
115
private func isMaterialized< T: TensorFlowScalar > ( _ input: Tensor < T > ) -> Bool {
0 commit comments