Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit bd3f264

Browse files
Revert "Remove deprecated `Dataset` (#1115)" (#1116)
This reverts commit 3f5d896.
1 parent 3f5d896 commit bd3f264

File tree

4 files changed

+286
-40
lines changed

4 files changed

+286
-40
lines changed

Sources/TensorFlow/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ add_library(TensorFlow SHARED
6060

6161
Operators/Basic.swift
6262
Operators/Comparison.swift
63+
Operators/Dataset.swift
6364
Operators/Image.swift
6465
Operators/LinearAlgebra.swift
6566
Operators/Math.swift
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/// The graph-level seed used by default.
///
/// - Note: See TensorFlow's `python.framework.random_seed.DEFAULT_GRAPH_SEED`.
@usableFromInline
@available(*, deprecated, message: "Graph-level tracing will be removed in S4TF v0.10")
let _defaultGraphSeed: Int64 = 87_654_321
20+
21+
/// Derives the pair of seeds an operation should use from an op-specific seed.
///
/// Many random operations take two seeds, one graph-level and one op-level, so that a user can
/// reseed either a whole graph or only specific operations. The first element of the returned
/// pair is the graph-level seed; the second is `seed`, passed through unchanged.
///
/// - Note: See TensorFlow's `python.framework.random_seed.get_seed`.
///
/// - Parameter seed: The operation-specific seed.
/// - Returns: The `(graph-level, op-level)` pair of seed tensors.
// TODO: There's no support for TF's "global seed" yet, so the default graph seed is always used
// as the first seed. Need to investigate the best way to model TF's "global seed".
@usableFromInline
@available(*, deprecated, message: "Graph-level tracing will be removed in S4TF v0.10")
func _tensorSeeds(_ seed: Tensor<Int64>) -> (Tensor<Int64>, Tensor<Int64>) {
  let graphSeed = Tensor<Int64>(_defaultGraphSeed, on: .defaultTFEager)
  return (graphSeed, seed)
}
36+
37+
//===------------------------------------------------------------------------------------------===//
38+
// Single Value Dataset
39+
//===------------------------------------------------------------------------------------------===//
40+
41+
/// A potentially large collection of elements.
///
/// Use a `Dataset` to model an input pipeline as a collection of element tensors. The value
/// itself only stores a `VariantHandle`.
@frozen
@available(
  *, deprecated,
  message:
    """
    Datasets will be removed in S4TF v0.10. Please use the new Batches API instead.
    """
)
public struct Dataset<Element: TensorGroup> {
  /// The variant handle this dataset wraps.
  public let _handle: VariantHandle

  /// Creates a dataset that wraps an existing variant handle.
  ///
  /// - Parameter handle: The handle to store in `_handle`.
  @inlinable
  public init(_handle handle: VariantHandle) {
    _handle = handle
  }
}
60+
61+
@available(*, deprecated)
extension Dataset {
  /// Creates a dataset backed by the `experimentalRandomDataset` op.
  ///
  /// - Parameter randomSeed: The op-level seed. It is expanded into a seed pair by
  ///   `_tensorSeeds` before being handed to the op.
  @inlinable
  public init(randomSeed: Int64) {
    let seeds = _tensorSeeds(Tensor(randomSeed, on: .defaultTFEager))
    let handle = _Raw.experimentalRandomDataset(
      seed: seeds.0,
      seed2: seeds.1,
      outputTypes: Element._typeList,
      outputShapes: Element._unknownShapeList)
    self.init(_handle: handle)
  }
}
74+
75+
@available(*, deprecated)
extension Dataset {
  /// Creates a dataset from a batch of elements as a tensor.
  ///
  /// - Parameter elements: The batch, passed as the single component of the
  ///   `tensorSliceDataset` op.
  @inlinable
  public init(elements: Element) {
    let handle = _Raw.tensorSliceDataset(
      components: [elements],
      outputShapes: Element._unknownShapeList)
    self.init(_handle: handle)
  }
}
86+
87+
@available(*, deprecated)
extension Dataset: Sequence {
  public typealias Iterator = DatasetIterator<Element>

  /// Returns an iterator over the elements of this dataset.
  ///
  /// A fresh anonymous iterator resource is created and bound to this dataset's handle on
  /// every call.
  @inlinable
  public func makeIterator() -> DatasetIterator<Element> {
    let iteratorResource = _Raw.anonymousIterator(
      outputTypes: Element._typeList,
      outputShapes: Element._unknownShapeList)
    // Bind the new iterator resource to this dataset before handing it out.
    _Raw.makeIterator(dataset: _handle, iterator: iteratorResource)
    return DatasetIterator(_handle: iteratorResource)
  }
}
101+
102+
@available(*, deprecated)
extension Dataset {
  // Note that this Dataset API implementation uses an experimental tracing feature, which is not
  // robust and does not have great diagnostics yet.

  /// Returns a dataset produced by the `mapDataset` op with `transform` as its function.
  ///
  /// - Parameter transform: The function from `Element` to `ResultElement` given to the op.
  /// - Returns: A dataset whose element type is `ResultElement`.
  @inlinable
  public func map<ResultElement: TensorGroup>(
    _ transform: (Element) -> ResultElement
  ) -> Dataset<ResultElement> {
    let mappedHandle = _Raw.mapDataset(
      inputDataset: _handle,
      otherArguments: Tensor<Int32>(0, on: .defaultTFEager),
      f: transform,
      outputTypes: ResultElement._typeList,
      outputShapes: ResultElement._unknownShapeList,
      useInterOpParallelism: true,
      preserveCardinality: false)
    return Dataset<ResultElement>(_handle: mappedHandle)
  }

  /// Returns a dataset produced by the `parallelMapDataset` op with `transform` as its function.
  ///
  /// - Parameters:
  ///   - parallelCallCount: Forwarded to the op as `numParallelCalls` after conversion to
  ///     `Int32`; the conversion traps if the value does not fit.
  ///   - transform: The function from `Element` to `ResultElement` given to the op.
  /// - Returns: A dataset whose element type is `ResultElement`.
  @inlinable
  public func map<ResultElement: TensorGroup>(
    parallelCallCount: Int,
    _ transform: (Element) -> ResultElement
  ) -> Dataset<ResultElement> {
    let mappedHandle = _Raw.parallelMapDataset(
      inputDataset: _handle,
      otherArguments: Tensor<Int32>(0, on: .defaultTFEager),
      numParallelCalls: Tensor<Int32>(Int32(parallelCallCount), on: .defaultTFEager),
      f: transform,
      outputTypes: ResultElement._typeList,
      outputShapes: ResultElement._unknownShapeList,
      useInterOpParallelism: true,
      sloppy: false,
      preserveCardinality: false)
    return Dataset<ResultElement>(_handle: mappedHandle)
  }

  /// Returns a dataset produced by the `filterDataset` op with `isIncluded` as its predicate.
  ///
  /// - Parameter isIncluded: The predicate from `Element` to a Boolean tensor given to the op.
  /// - Returns: A dataset with the same element type as this one.
  @inlinable
  public func filter(_ isIncluded: (Element) -> Tensor<Bool>) -> Dataset {
    let filteredHandle = _Raw.filterDataset(
      inputDataset: _handle,
      otherArguments: Tensor<Int32>(0, on: .defaultTFEager),
      predicate: isIncluded,
      outputTypes: Element._typeList,
      outputShapes: Element._unknownShapeList)
    return Dataset(_handle: filteredHandle)
  }
}
150+
151+
@available(*, deprecated)
extension Dataset {
  /// Returns a dataset produced by the `prefetchDataset` op.
  ///
  /// - Parameter count: Forwarded to the op as its `bufferSize`.
  @inlinable
  public func prefetched(count: Int) -> Dataset {
    let bufferSize = Tensor(Int64(count), on: .defaultTFEager)
    let handle = _Raw.prefetchDataset(
      inputDataset: _handle,
      bufferSize: bufferSize,
      outputTypes: Element._typeList,
      outputShapes: Element._unknownShapeList)
    return Dataset(_handle: handle)
  }

  /// Returns a dataset produced by the `shuffleDataset` op.
  ///
  /// - Parameters:
  ///   - sampleCount: Forwarded to the op as its `bufferSize`.
  ///   - randomSeed: The op-level seed; expanded into a seed pair by `_tensorSeeds`.
  ///   - reshuffleForEachIterator: Forwarded to the op as `reshuffleEachIteration`.
  @inlinable
  public func shuffled(
    sampleCount: Int,
    randomSeed: Int64,
    reshuffleForEachIterator: Bool = true
  ) -> Dataset {
    let seeds = _tensorSeeds(Tensor(randomSeed, on: .defaultTFEager))
    let bufferSize = Tensor(Int64(sampleCount), on: .defaultTFEager)
    let handle = _Raw.shuffleDataset(
      inputDataset: _handle,
      bufferSize: bufferSize,
      seed: seeds.0,
      seed2: seeds.1,
      reshuffleEachIteration: reshuffleForEachIterator,
      outputTypes: Element._typeList,
      outputShapes: Element._unknownShapeList)
    return Dataset(_handle: handle)
  }

  /// Returns a dataset produced by the `batchDataset` op.
  ///
  /// - Parameter batchSize: Forwarded to the op as its `batchSize`.
  @inlinable
  public func batched(_ batchSize: Int) -> Dataset {
    let size = Tensor(Int64(batchSize), on: .defaultTFEager)
    let handle = _Raw.batchDataset(
      inputDataset: _handle,
      batchSize: size,
      outputTypes: Element._typeList,
      outputShapes: Element._unknownShapeList)
    return Dataset(_handle: handle)
  }

  /// Returns a dataset produced by the `repeatDataset` op.
  ///
  /// - Parameter count: The repetition count forwarded to the op. When `nil`, `-1` is sent
  ///   instead (presumably "repeat indefinitely"; confirm against the op's documentation).
  @inlinable
  public func repeated(count: Int? = nil) -> Dataset {
    let repetitions = Tensor(Int64(count ?? -1), on: .defaultTFEager)
    let handle = _Raw.repeatDataset(
      inputDataset: _handle,
      count: repetitions,
      outputTypes: Element._typeList,
      outputShapes: Element._unknownShapeList)
    return Dataset(_handle: handle)
  }
}
201+
202+
/// The type that allows iteration over a dataset's elements.
@frozen
@available(*, deprecated)
public struct DatasetIterator<Element: TensorGroup> {
  /// The iterator resource handle this value wraps.
  @usableFromInline let _handle: ResourceHandle

  /// Creates an iterator that wraps an existing iterator resource handle.
  @usableFromInline
  internal init(_handle resource: ResourceHandle) {
    _handle = resource
  }
}
213+
214+
@available(*, deprecated)
extension DatasetIterator: IteratorProtocol {
  /// Advances to the next element and returns it, or `nil` if no next element exists.
  @inlinable
  public mutating func next() -> Element? {
    let maybeElement = _Raw.iteratorGetNextAsOptional(
      iterator: _handle,
      outputTypes: Element._typeList,
      outputShapes: Element._unknownShapeList)
    // Only unwrap the optional variant once it reports that it holds a value.
    let hasValue = _Raw.optionalHasValue(optional: maybeElement).scalarized()
    if !hasValue {
      return nil
    }
    return _Raw.optionalGetValue(
      optional: maybeElement,
      outputShapes: Element._unknownShapeList)
  }
}
231+
232+
/// A 2-tuple-like struct that conforms to `TensorGroup` and represents a pair of values whose
/// types both conform to `TensorGroup`.
///
/// In every flattened representation (`_typeList`, `_tensorHandles`, and the raw handle
/// buffers) the handles of `first` come before the handles of `second`.
@frozen
public struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup {
  /// The first component of the pair.
  public var first: T
  /// The second component of the pair.
  public var second: U

  /// Creates a pair from its two components.
  public init(_ first: T, _ second: U) {
    self.first = first
    self.second = second
  }

  /// The data types of `T`'s tensors followed by those of `U`'s tensors.
  public static var _typeList: [TensorDataType] { return T._typeList + U._typeList }

  /// Creates a pair from a buffer of tensor handles.
  ///
  /// - Parameter tensorHandles: The buffer start. `first` is read from the start of the buffer
  ///   and `second` from the position `T._tensorHandleCount` handles later.
  public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
    first = .init(_owning: tensorHandles)
    second = .init(_owning: tensorHandles?.advanced(by: Int(T._tensorHandleCount)))
  }

  /// Writes the handles of `first`, then the handles of `second`, into a buffer.
  ///
  /// - Parameter address: The buffer start. A `nil` address is forwarded to both components
  ///   rather than force-unwrapped, matching how `init(_owning:)` treats its optional pointer.
  public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
    first._unpackTensorHandles(into: address)
    // `second` starts right after the handles that `first` occupies.
    second._unpackTensorHandles(
      into: address?.advanced(by: Int(first._tensorHandleCount)))
  }

  /// The handles of `first` followed by the handles of `second`.
  public var _tensorHandles: [_AnyTensorHandle] {
    first._tensorHandles + second._tensorHandles
  }

  /// Creates a pair from a collection of handles.
  ///
  /// - Parameter _handles: The handles to split. The first `T._tensorHandleCount` handles
  ///   build `first`; all remaining handles build `second`.
  public init<C: RandomAccessCollection>(
    _handles: C
  ) where C.Element: _AnyTensorHandle {
    let firstStart = _handles.startIndex
    let firstEnd = _handles.index(
      firstStart, offsetBy: Int(T._tensorHandleCount))
    self.first = T.init(_handles: _handles[firstStart..<firstEnd])
    self.second = U.init(_handles: _handles[firstEnd..<_handles.endIndex])
  }
}
272+
273+
// TODO(SR-9156): This does not work in graph mode.
/// Returns a dataset produced by the `zipDataset` op from two input datasets.
///
/// - Parameters:
///   - dataset1: The dataset supplying the `T` part of each `Zip2TensorGroup` element.
///   - dataset2: The dataset supplying the `U` part of each `Zip2TensorGroup` element.
/// - Returns: A dataset whose element type is `Zip2TensorGroup<T, U>`.
@inlinable
@available(*, deprecated, message: "Graph-level tracing will be removed in S4TF v0.10")
public func zip<T: TensorGroup, U: TensorGroup>(
  _ dataset1: Dataset<T>, _ dataset2: Dataset<U>
) -> Dataset<Zip2TensorGroup<T, U>> {
  let inputHandles = [dataset1._handle, dataset2._handle]
  let zippedHandle = _Raw.zipDataset(
    inputDatasets: inputHandles,
    outputTypes: Zip2TensorGroup<T, U>._typeList,
    outputShapes: Zip2TensorGroup<T, U>._unknownShapeList)
  return Dataset(_handle: zippedHandle)
}

Tests/TensorFlowTests/LazyTensorExplicitTraceTests.swift

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -189,43 +189,3 @@ final class LazyTensorExplicitTraceTests: LazyTensorTestCase {
189189
("testRetainsIdenticalOutputs", testRetainsIdenticalOutputs),
190190
]
191191
}
192-
193-
/// A 2-tuple-like struct that conforms to `TensorGroup` and represents a pair of values whose
/// types both conform to `TensorGroup`.
///
/// In every flattened representation the handles of `first` come before those of `second`.
fileprivate struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup {
  /// The first component of the pair.
  var first: T
  /// The second component of the pair.
  var second: U

  /// Creates a pair from its two components.
  init(_ first: T, _ second: U) {
    self.first = first
    self.second = second
  }

  /// The data types of `T`'s tensors followed by those of `U`'s tensors.
  static var _typeList: [TensorDataType] { return T._typeList + U._typeList }

  /// Creates a pair from a buffer of tensor handles: `first` is read from the start of the
  /// buffer and `second` from the position `T._tensorHandleCount` handles later.
  init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
    first = .init(_owning: tensorHandles)
    second = .init(_owning: tensorHandles?.advanced(by: Int(T._tensorHandleCount)))
  }

  /// Writes the handles of `first`, then the handles of `second`, into a buffer.
  ///
  /// A `nil` address is forwarded to both components rather than force-unwrapped, matching
  /// how `init(_owning:)` treats its optional pointer.
  func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
    first._unpackTensorHandles(into: address)
    // `second` starts right after the handles that `first` occupies.
    second._unpackTensorHandles(
      into: address?.advanced(by: Int(first._tensorHandleCount)))
  }

  /// The handles of `first` followed by the handles of `second`.
  var _tensorHandles: [_AnyTensorHandle] {
    first._tensorHandles + second._tensorHandles
  }

  /// Creates a pair from a collection of handles: the first `T._tensorHandleCount` handles
  /// build `first`; all remaining handles build `second`.
  init<C: RandomAccessCollection>(
    _handles: C
  ) where C.Element: _AnyTensorHandle {
    let firstStart = _handles.startIndex
    let firstEnd = _handles.index(
      firstStart, offsetBy: Int(T._tensorHandleCount))
    self.first = T.init(_handles: _handles[firstStart..<firstEnd])
    self.second = U.init(_handles: _handles[firstEnd..<_handles.endIndex])
  }
}

Tests/TensorFlowTests/LazyTensorTraceTests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ final class LazyTensorTraceTests: LazyTensorTestCase {
204204
}
205205

206206
func testTraceWithFunctionAttributes() {
207+
typealias Int32Pair = Zip2TensorGroup<Tensor<Int32>, Tensor<Int32>>
207208
func thenBranch(x: Tensor<Float>) -> Tensor<Float> {
208209
return x + 10.0
209210
}

0 commit comments

Comments
 (0)