@@ -47,42 +47,58 @@ class ComputationClient {
   using DataPtr = std::shared_ptr<Data>;
   using ComputationPtr = std::shared_ptr<Computation>;
 
+  class TransferManager {
+   public:
+    virtual ~TransferManager() {}
+
+    virtual std::vector<Literal> TransferFromServerImpl(
+        absl::Span<const DataPtr> handles) = 0;
+  };
+
   class Device {
    public:
     virtual ~Device() {}
 
     const std::string& name() const { return name_; }
-    ComputationClient* computation_client() const { return client_; }
+    virtual int32_t mesh_id() const;
     const swift_xla::Device& device_id() const { return device_id_; }
-    explicit Device(std::string name, ComputationClient* client)
-        : name_(name), client_(client) {
+
+    virtual TransferManager* GetTransferManager() const = 0;
+    explicit Device(std::string name) : name_(name) {
       device_id_ = swift_xla::Device(name_);
     }
 
     virtual std::vector<ComputationPtr> Compile(
         const std::vector<std::string>& devices,
-        std::vector<CompileInstance> instances);
+        std::vector<CompileInstance> instances) = 0;
 
+    // Transfers local tensor values to the TPU servers and fetches the handles.
     virtual std::vector<DataPtr> TransferToServer(
-        absl::Span<const TensorSource> tensors);
+        absl::Span<const TensorSource> tensors) = 0;
 
+    // Copies a single tensor in the form of a xla::BorrowingLiteral async to
+    // the TPU. The literal is copied to a temporary buffer and then copied
+    // async as per the semantics of TransferLiteralToDeviceAsync. The next
+    // computation that is scheduled will wait for this transfer to complete
+    // before running.
     virtual DataPtr TransferToServer(xla::BorrowingLiteral literal,
                                      const xla::Shape& dest_shape);
 
     virtual std::vector<DataPtr> ExecuteChained(
-        absl::Span<const ExecuteChainedOp> ops);
+        absl::Span<const ExecuteChainedOp> ops) = 0;
 
-    virtual std::string ResourceDomain() const;
+    virtual std::string ResourceDomain() const = 0;
 
-    virtual DataPtr CreateDataPlaceholder(Shape shape) const;
+    virtual DataPtr CreateDataPlaceholder(Shape shape) = 0;
 
     virtual std::vector<DataPtr> ExecuteComputation(
         const Computation& computation, absl::Span<const DataPtr> arguments,
-        const ExecuteComputationOptions& options);
+        const ExecuteComputationOptions& options) = 0;
+
+    virtual bool IsLocal() { return false; }
 
    private:
     std::string name_;
-    ComputationClient* client_;
     swift_xla::Device device_id_;
   };
 
   class Data {
@@ -152,13 +168,10 @@ class ComputationClient {
     using PopulateFn = std::function<void(const TensorSource&, void*, size_t)>;
 
     TensorSource() = default;
-    TensorSource(Shape shape, std::string device, PopulateFn populate_fn)
-        : shape(std::move(shape)),
-          device(std::move(device)),
-          populate_fn(std::move(populate_fn)) {}
+    TensorSource(Shape shape, PopulateFn populate_fn)
+        : shape(std::move(shape)), populate_fn(std::move(populate_fn)) {}
 
     Shape shape;
-    std::string device;
     PopulateFn populate_fn;
   };
 
@@ -210,85 +223,13 @@ class ComputationClient {
 
   static std::unique_ptr<ComputationClient> Create();
 
-  static bool IsLocal();
-
   virtual ~ComputationClient() {}
 
-  // Creates a Data object with no actual device handle in it. The device handle
-  // will be populated in an asynchrounous fashion.
-  virtual DataPtr CreateDataPlaceholder(std::string device, Shape shape) = 0;
-
-  // Transfers local tensor values to the TPU servers and fetches the handles.
-  virtual DataPtr TransferToServer(xla::BorrowingLiteral literal,
-                                   const xla::Shape& dest_shape,
-                                   const std::string& device);
-
-  // Transfers local tensor values to the TPU servers and fetches the handles.
-  virtual std::vector<DataPtr> TransferToServer(
-      absl::Span<const TensorSource> tensors) = 0;
-
   // Reads the tensor literal values stored at TPU server sites, behind the
   // supplied handles.
   static std::vector<Literal> TransferFromServer(
       absl::Span<const DataPtr> handles);
 
-  std::vector<ComputationPtr> Compile(std::vector<CompileInstance> instances);
-
-  // Executes computation with arguments and returns the result.
-  // The passed device must match the common device of the arguments Data.
-  // If options.explode_tuple is true, the output tuple will be decomposed into
-  // its single elements.
-  virtual std::vector<DataPtr> ExecuteComputation(
-      const Computation& computation, absl::Span<const DataPtr> arguments,
-      const std::string& device, const ExecuteComputationOptions& options) = 0;
-
-  // Executes the computation in replicated mode.
-  // The size of the arguments vector is the number of replicas to execute,
-  // and it must match the size of the computation.devices() as well as the
-  // devices passed as argument. The destination devices for each replicated
-  // computation come from the devices the Data objects are stored into, which
-  // must match the devices argument. Within arguments[i], every Data
-  // object must be coming from the same device. Returns a vector (of the same
-  // size of the arguments vector) with the results of the parallel execution.
-  // The result[i], a vector itself, will be the result of the computation fed
-  // with arguments[i]. If options.explode_tuple is true, the output tuples will
-  // be decomposed into their single elements.
-  virtual std::vector<std::vector<DataPtr>> ExecuteReplicated(
-      const Computation& computation,
-      const std::vector<std::vector<DataPtr>>& arguments,
-      absl::Span<const std::string> devices,
-      const ExecuteReplicatedOptions& options) = 0;
-
-  // Executes the computations in parallel. Each computation must target a
-  // different device, and the the common device of arguments[i] must match
-  // devices[i]. The computations[i] computation is fed with arguments[i]
-  // arguments.
-  // Returns a vector of vectors of device side Data object, with result[i]
-  // being the return value of computations[i]. If options.explode_tuple is
-  // true, the output tuples will be decomposed into their single elements.
-  virtual std::vector<std::vector<DataPtr>> ExecuteParallel(
-      absl::Span<const Computation* const> computations,
-      const std::vector<std::vector<DataPtr>>& arguments,
-      absl::Span<const std::string> devices,
-      const ExecuteParallelOptions& options) = 0;
-
-  // Executes a serie of operations, whose results are input of other
-  // operations. The ops is a valid post-order for the execution, which means
-  // that the inputs of op at index I, will have to be coming from ops at index
-  // lower than I. It returns a vector of device data shared pointers, one for
-  // every ExecuteChainedOp marked with is_result=true, in the order they appear
-  // within the ops post-order.
-  virtual std::vector<DataPtr> ExecuteChained(
-      absl::Span<const ExecuteChainedOp> ops, const std::string& device) = 0;
-
-  virtual std::vector<std::vector<DataPtr>> DeconstructTuple(
-      absl::Span<const DataPtr> tuples) = 0;
-
-  // Returns a unique string which identifies the resource domain of a given
-  // device. Within a resource domain, handles to device memory or compiled
-  // computations can be used for all devices part of such domain.
-  virtual std::string GetResourceDomain(const std::string& device) const = 0;
-
   virtual std::string GetDefaultDevice() const = 0;
   static Device* DefaultDevice();
 
@@ -297,11 +238,6 @@ class ComputationClient {
   virtual swift_xla::Device GetDefaultDeviceStruct() const = 0;
   static swift_xla::Device DefaultDeviceStruct();
 
-  virtual size_t GetNumDevices() const = 0;
-
-  virtual std::vector<std::string> GetLocalDevices() const = 0;
-  static std::vector<std::string> LocalDevices();
-
   std::vector<std::string> GetAllDevices() const;
   static std::vector<std::string> AllDevices();
 
@@ -336,7 +272,6 @@ class ComputationClient {
   // after the last ':' character of the device string.
   static int64 GetDeviceOrdinal(const std::string& device);
 
- protected:
   // Metrics common to all client intrfaces.
   static metrics::Metric* TransferToServerMetric();
   static metrics::Metric* TransferToServerTransformMetric();
@@ -357,19 +292,14 @@ class ComputationClient {
   static metrics::Metric* ReleaseCompileHandlesTimeMetric();
   static metrics::Metric* InboundDataMetric();
   static metrics::Metric* OutboundDataMetric();
+
+ protected:
   void AddDevice(std::unique_ptr<Device> device);
 
   // Returns the ComputationClient singleton.
   static ComputationClient* Get();
 
  private:
-  virtual std::vector<Literal> TransferFromServerImpl(
-      absl::Span<const DataPtr> handles) = 0;
-  // Compiles a set of computations.
-  virtual std::vector<ComputationPtr> Compile(
-      const std::string& device, const std::vector<std::string>& devices,
-      std::vector<CompileInstance> instances) = 0;
-
   std::vector<Device*> devices_;
   std::vector<std::unique_ptr<Device>> devices_owned_;
   absl::node_hash_map<std::string, Device*> devices_by_name_;
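
For context, here is a minimal sketch of what a concrete backend could look like against the interface after this change. It is illustrative only: FakeTransferManager, FakeDevice, and their trivial bodies are hypothetical, the sketch assumes it lives in the same namespace as ComputationClient with the header's usual includes, and only a couple of the now pure-virtual Device members are spelled out.

// Hypothetical backend stub -- not part of this change.
class FakeTransferManager : public ComputationClient::TransferManager {
 public:
  std::vector<Literal> TransferFromServerImpl(
      absl::Span<const ComputationClient::DataPtr> handles) override {
    // A real backend would materialize one Literal per device handle here.
    return {};
  }
};

class FakeDevice : public ComputationClient::Device {
 public:
  explicit FakeDevice(std::string name) : Device(std::move(name)) {}

  ComputationClient::TransferManager* GetTransferManager() const override {
    return transfer_manager_.get();
  }
  std::string ResourceDomain() const override { return name(); }

  // The remaining pure-virtual members (Compile, TransferToServer,
  // ExecuteComputation, ExecuteChained, CreateDataPlaceholder) would be
  // overridden the same way; omitted here for brevity.

 private:
  std::unique_ptr<FakeTransferManager> transfer_manager_ =
      std::make_unique<FakeTransferManager>();
};

The net effect of the diff is that per-device behavior (compilation, transfers, execution, resource domains) now hangs off Device's virtual interface instead of going through ComputationClient methods that took a device string, and the former private TransferFromServerImpl/Compile hooks move onto the per-device TransferManager and Device classes.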