This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 2042ed3

Add an xrt device and move the ComputationClient api there. (#1106)
1 parent: e69350f

10 files changed: +731, -235 lines


Sources/CX10/xla_tensor_wrapper.cc

Lines changed: 3 additions & 3 deletions
@@ -138,13 +138,13 @@ OpaqueXLATensor* copyTensorAndMakeResident(enum XLATensorScalarType type,
                                            const size_t* shape, size_t rank,
                                            const struct CDevice cdevice,
                                            bool to_reduced_precision) {
+  auto device = ConvertDevice(cdevice);
   if (to_reduced_precision && XLATensorScalarType_Float == type) {
     const float* float_buffer = reinterpret_cast<const float*>(value);
     auto non_owned_buffer =
         std::make_unique<at::NonOwnedAnyScalarBuffer<float>>(
             float_buffer, num_entries * sizeof(float));
     std::vector<int64_t> dims(shape, shape + rank);
-    auto device = ConvertDevice(cdevice);
     auto dest_shape = swift_xla::MakeArrayShapeFromDimensions(
         XlaHelpers::I64List(dims), /*dynamic_dimensions=*/{},
         xla::PrimitiveType::BF16, device.hw_type);
@@ -153,8 +153,8 @@ OpaqueXLATensor* copyTensorAndMakeResident(enum XLATensorScalarType type,
     return new swift_xla::XLATensor(
         swift_xla::XLATensor::Create(xla_data, at::ScalarType::Float));
   }
-  if (XLATensorScalarType_Float == type && xla::ComputationClient::IsLocal()) {
-    auto device = ConvertDevice(cdevice);
+  auto* device_ptr = xla::GetX10Device(device);
+  if (XLATensorScalarType_Float == type && device_ptr->IsLocal()) {
     std::vector<xla::int64> dims(shape, shape + rank);
     auto dest_shape = swift_xla::MakeArrayShapeFromDimensions(
         dims, /*dynamic_dimensions=*/{}, xla::PrimitiveType::F32,
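The hunk above resolves the device once up front and moves the "is this a local device?" check from the process-wide xla::ComputationClient::IsLocal() onto the resolved device object. A minimal sketch of the resulting call pattern, using only names visible in the diff (ConvertDevice, xla::GetX10Device, IsLocal); the wrapper function ShouldUseLocalFloatPath is hypothetical:

// Hypothetical wrapper, for illustration only: resolve the device once and
// ask it, rather than the global client, whether the local fast path applies.
static bool ShouldUseLocalFloatPath(enum XLATensorScalarType type,
                                    const struct CDevice cdevice) {
  auto device = ConvertDevice(cdevice);          // CDevice -> swift_xla::Device
  auto* device_ptr = xla::GetX10Device(device);  // -> ComputationClient::Device*
  return XLATensorScalarType_Float == type && device_ptr->IsLocal();
}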

Sources/x10/xla_client/BUILD

Lines changed: 7 additions & 0 deletions
@@ -24,6 +24,7 @@ cc_library(
         "computation_client.cc",
         "device.cc",
         "env_vars.cc",
+        "local_device.cc",
         "mesh_service.cc",
         "metrics.cc",
         "metrics_reader.cc",
@@ -47,6 +48,7 @@ cc_library(
         "debug_macros.h",
         "device.h",
         "env_vars.h",
+        "local_device.h",
         "mesh_service.h",
         "metrics.h",
         "metrics_reader.h",
@@ -83,8 +85,10 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_proto_cc",
+        "//tensorflow/compiler/xla/service:platform_util",
         "//tensorflow/compiler/xrt:xrt_proto_cc",
         "//tensorflow/compiler/xrt:xrt_server",
         "//tensorflow/compiler/xrt:xrt_utils",
@@ -98,10 +102,13 @@ cc_library(
         "//tensorflow/core/kernels:conv_ops",
         "//tensorflow/core/kernels:data_flow",
         "//tensorflow/core/protobuf/tpu:topology_proto_cc",
+        "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/numeric:int128",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/container:node_hash_set",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
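The build now compiles a local_device translation unit next to computation_client. Its contents are not part of this page, but given the Device interface introduced in computation_client.h below, a local backend would plausibly be a Device subclass along these lines; this is a hypothetical skeleton, not the actual local_device.h:

// Hypothetical skeleton, assuming only the Device interface shown in this
// commit; every name except ComputationClient::Device is invented.
#include "computation_client.h"  // as laid out in Sources/x10/xla_client/

namespace xla {

class ExampleLocalDevice : public ComputationClient::Device {
 public:
  explicit ExampleLocalDevice(std::string name)
      : ComputationClient::Device(std::move(name)) {}

  // An in-process device advertises itself as local, enabling fast paths
  // such as the one added in xla_tensor_wrapper.cc above.
  bool IsLocal() override { return true; }

  // The pure-virtual members (Compile, TransferToServer, ExecuteChained,
  // ResourceDomain, CreateDataPlaceholder, ExecuteComputation,
  // GetTransferManager) would be implemented against a local client; they
  // are elided here, so this skeleton remains abstract.
};

}  // namespace xla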

Sources/x10/xla_client/computation_client.cc

Lines changed: 7 additions & 47 deletions
@@ -38,7 +38,7 @@ std::shared_ptr<ComputationClient::Computation> ComputationClient::Compile(
   std::vector<CompileInstance> instances;
   instances.emplace_back(std::move(computation), output_shape);
   std::vector<std::shared_ptr<Computation>> results =
-      Compile(compilation_device, devices, std::move(instances));
+      GetX10Device(compilation_device)->Compile(devices, std::move(instances));
   return std::move(results[0]);
 }

@@ -183,10 +183,8 @@ metrics::Metric* ComputationClient::OutboundDataMetric() {
   return metric;
 }

-ComputationClient::DataPtr ComputationClient::TransferToServer(
-    xla::BorrowingLiteral literal, const xla::Shape& dest_shape,
-    const std::string& device) {
-  TF_LOG(FATAL) << "Only supported for LocalClient";
+int32_t ComputationClient::Device::mesh_id() const {
+  TF_LOG(FATAL) << "Unsupported";
 }

 std::vector<std::string> ComputationClient::GetAllDevices() const {
@@ -225,50 +223,16 @@ ComputationClient::Device* GetX10Device(swift_xla::Device device_id) {
 std::vector<Literal> ComputationClient::TransferFromServer(
     absl::Span<const DataPtr> handles) {
   if (handles.empty()) return {};
-  ComputationClient* client = handles[0]->device()->computation_client();
+  TransferManager* transfer = handles[0]->device()->GetTransferManager();
   for (auto& handle : handles) {
-    XLA_CHECK_EQ(client, handle->device()->computation_client());
+    XLA_CHECK_EQ(transfer, handle->device()->GetTransferManager());
   }
-  return client->TransferFromServerImpl(handles);
-}
-
-std::vector<ComputationClient::ComputationPtr>
-ComputationClient::Device::Compile(const std::vector<std::string>& devices,
-                                   std::vector<CompileInstance> instances) {
-  return client_->Compile(name_, devices, std::move(instances));
+  return transfer->TransferFromServerImpl(handles);
 }

 ComputationClient::DataPtr ComputationClient::Device::TransferToServer(
     xla::BorrowingLiteral literal, const xla::Shape& dest_shape) {
-  return client_->TransferToServer(std::move(literal), dest_shape, name_);
-}
-
-std::vector<ComputationClient::DataPtr>
-ComputationClient::Device::TransferToServer(
-    absl::Span<const TensorSource> tensors) {
-  return client_->TransferToServer(tensors);
-}
-
-std::vector<ComputationClient::DataPtr>
-ComputationClient::Device::ExecuteChained(
-    absl::Span<const ComputationClient::ExecuteChainedOp> ops) {
-  return client_->ExecuteChained(ops, name_);
-}
-
-std::string ComputationClient::Device::ResourceDomain() const {
-  return client_->GetResourceDomain(name_);
-}
-
-ComputationClient::DataPtr ComputationClient::Device::CreateDataPlaceholder(
-    Shape shape) const {
-  return client_->CreateDataPlaceholder(name_, std::move(shape));
-}
-
-std::vector<ComputationClient::DataPtr>
-ComputationClient::Device::ExecuteComputation(
-    const Computation& computation, absl::Span<const DataPtr> arguments,
-    const ExecuteComputationOptions& options) {
-  return client_->ExecuteComputation(computation, arguments, name_, options);
+  TF_LOG(FATAL) << "Only supported for LocalClient";
 }

 std::map<std::string, Metric> ComputationClient::ReadMetrics() {
@@ -299,8 +263,4 @@ std::vector<std::string> ComputationClient::AllDevices() {
   return Get()->GetAllDevices();
 }

-std::vector<std::string> ComputationClient::LocalDevices() {
-  return Get()->GetLocalDevices();
-}
-
 }  // namespace xla
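After this change, uploads and downloads are addressed to a Device object rather than to a device name string, and the static TransferFromServer only checks that all handles share one TransferManager before delegating to it. A hedged usage sketch of the round trip; RoundTrip is an invented name and the literal and shape are assumed to be built elsewhere:

// Illustrative only: upload a literal through a Device and read it back
// through the static helper shown above.
std::vector<xla::Literal> RoundTrip(xla::ComputationClient::Device* device,
                                    xla::BorrowingLiteral literal,
                                    const xla::Shape& dest_shape) {
  // Device-scoped upload; this replaces the removed
  // ComputationClient::TransferToServer(literal, shape, device_name) overload.
  xla::ComputationClient::DataPtr handle =
      device->TransferToServer(std::move(literal), dest_shape);

  // The static download path groups handles by their device's TransferManager
  // and forwards to TransferFromServerImpl.
  return xla::ComputationClient::TransferFromServer({handle});
}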

Sources/x10/xla_client/computation_client.h

Lines changed: 30 additions & 100 deletions
@@ -47,42 +47,58 @@ class ComputationClient
   using DataPtr = std::shared_ptr<Data>;
   using ComputationPtr = std::shared_ptr<Computation>;

+  class TransferManager {
+   public:
+    virtual ~TransferManager() {}
+
+    virtual std::vector<Literal> TransferFromServerImpl(
+        absl::Span<const DataPtr> handles) = 0;
+  };
+
   class Device {
    public:
     virtual ~Device() {}

     const std::string& name() const { return name_; }
-    ComputationClient* computation_client() const { return client_; }
+    virtual int32_t mesh_id() const;
     const swift_xla::Device& device_id() const { return device_id_; }
-    explicit Device(std::string name, ComputationClient* client)
-        : name_(name), client_(client) {
+
+    virtual TransferManager* GetTransferManager() const = 0;
+    explicit Device(std::string name) : name_(name) {
       device_id_ = swift_xla::Device(name_);
     }

     virtual std::vector<ComputationPtr> Compile(
         const std::vector<std::string>& devices,
-        std::vector<CompileInstance> instances);
+        std::vector<CompileInstance> instances) = 0;

+    // Transfers local tensor values to the TPU servers and fetches the handles.
     virtual std::vector<DataPtr> TransferToServer(
-        absl::Span<const TensorSource> tensors);
+        absl::Span<const TensorSource> tensors) = 0;

+    // Copies a single tensor in the form of a xla::BorrowingLiteral async to
+    // the TPU. The literal is copied to a temporary buffer and then copied
+    // async as per the semantics of TransferLiteralToDeviceAsync. The next
+    // computation that is scheduled will wait for this transfer to complete
+    // before running.
     virtual DataPtr TransferToServer(xla::BorrowingLiteral literal,
                                      const xla::Shape& dest_shape);

     virtual std::vector<DataPtr> ExecuteChained(
-        absl::Span<const ExecuteChainedOp> ops);
+        absl::Span<const ExecuteChainedOp> ops) = 0;

-    virtual std::string ResourceDomain() const;
+    virtual std::string ResourceDomain() const = 0;

-    virtual DataPtr CreateDataPlaceholder(Shape shape) const;
+    virtual DataPtr CreateDataPlaceholder(Shape shape) = 0;

     virtual std::vector<DataPtr> ExecuteComputation(
         const Computation& computation, absl::Span<const DataPtr> arguments,
-        const ExecuteComputationOptions& options);
+        const ExecuteComputationOptions& options) = 0;
+
+    virtual bool IsLocal() { return false; }

    private:
     std::string name_;
-    ComputationClient* client_;
     swift_xla::Device device_id_;
   };
   class Data {
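The new nested TransferManager interface is the hook through which the static TransferFromServer reaches a backend, and each Device must hand one back from GetTransferManager(). A hypothetical outline of an implementation, with the backend-specific fetch elided:

// Hypothetical outline; ExampleTransferManager is an invented name and the
// per-handle fetch is a placeholder.
class ExampleTransferManager : public xla::ComputationClient::TransferManager {
 public:
  std::vector<xla::Literal> TransferFromServerImpl(
      absl::Span<const xla::ComputationClient::DataPtr> handles) override {
    std::vector<xla::Literal> results;
    results.reserve(handles.size());
    for (const auto& handle : handles) {
      (void)handle;  // A real backend would read the literal behind `handle`.
      results.push_back(xla::Literal());
    }
    return results;
  }
};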
@@ -152,13 +168,10 @@
     using PopulateFn = std::function<void(const TensorSource&, void*, size_t)>;

     TensorSource() = default;
-    TensorSource(Shape shape, std::string device, PopulateFn populate_fn)
-        : shape(std::move(shape)),
-          device(std::move(device)),
-          populate_fn(std::move(populate_fn)) {}
+    TensorSource(Shape shape, PopulateFn populate_fn)
+        : shape(std::move(shape)), populate_fn(std::move(populate_fn)) {}

     Shape shape;
-    std::string device;
     PopulateFn populate_fn;
   };
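TensorSource no longer carries a device field, so the destination is determined solely by which Device receives the sources. A sketch of the new construction and upload path; UploadOnes is an invented name and the usual <cstring>/<algorithm> includes are assumed:

// Sketch only: build a TensorSource with the new two-argument constructor and
// hand it to a specific Device instead of naming a device string.
xla::ComputationClient::DataPtr UploadOnes(
    xla::ComputationClient::Device* device) {
  std::vector<float> host_data(16, 1.0f);
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {16});

  xla::ComputationClient::TensorSource source(
      shape,
      [&host_data](const xla::ComputationClient::TensorSource& /*src*/,
                   void* dest, size_t size) {
        // Fill the staging buffer the backend hands us.
        std::memcpy(dest, host_data.data(),
                    std::min(size, host_data.size() * sizeof(float)));
      });

  auto handles = device->TransferToServer(absl::MakeConstSpan(&source, 1));
  return handles.front();
}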

@@ -210,85 +223,13 @@

   static std::unique_ptr<ComputationClient> Create();

-  static bool IsLocal();
-
   virtual ~ComputationClient() {}

-  // Creates a Data object with no actual device handle in it. The device handle
-  // will be populated in an asynchrounous fashion.
-  virtual DataPtr CreateDataPlaceholder(std::string device, Shape shape) = 0;
-
-  // Transfers local tensor values to the TPU servers and fetches the handles.
-  virtual DataPtr TransferToServer(xla::BorrowingLiteral literal,
-                                   const xla::Shape& dest_shape,
-                                   const std::string& device);
-
-  // Transfers local tensor values to the TPU servers and fetches the handles.
-  virtual std::vector<DataPtr> TransferToServer(
-      absl::Span<const TensorSource> tensors) = 0;
-
   // Reads the tensor literal values stored at TPU server sites, behind the
   // supplied handles.
   static std::vector<Literal> TransferFromServer(
       absl::Span<const DataPtr> handles);

-  std::vector<ComputationPtr> Compile(std::vector<CompileInstance> instances);
-
-  // Executes computation with arguments and returns the result.
-  // The passed device must match the common device of the arguments Data.
-  // If options.explode_tuple is true, the output tuple will be decomposed into
-  // its single elements.
-  virtual std::vector<DataPtr> ExecuteComputation(
-      const Computation& computation, absl::Span<const DataPtr> arguments,
-      const std::string& device, const ExecuteComputationOptions& options) = 0;
-
-  // Executes the computation in replicated mode.
-  // The size of the arguments vector is the number of replicas to execute,
-  // and it must match the size of the computation.devices() as well as the
-  // devices passed as argument. The destination devices for each replicated
-  // computation come from the devices the Data objects are stored into, which
-  // must match the devices argument. Within arguments[i], every Data
-  // object must be coming from the same device. Returns a vector (of the same
-  // size of the arguments vector) with the results of the parallel execution.
-  // The result[i], a vector itself, will be the result of the computation fed
-  // with arguments[i]. If options.explode_tuple is true, the output tuples will
-  // be decomposed into their single elements.
-  virtual std::vector<std::vector<DataPtr>> ExecuteReplicated(
-      const Computation& computation,
-      const std::vector<std::vector<DataPtr>>& arguments,
-      absl::Span<const std::string> devices,
-      const ExecuteReplicatedOptions& options) = 0;
-
-  // Executes the computations in parallel. Each computation must target a
-  // different device, and the the common device of arguments[i] must match
-  // devices[i]. The computations[i] computation is fed with arguments[i]
-  // arguments.
-  // Returns a vector of vectors of device side Data object, with result[i]
-  // being the return value of computations[i]. If options.explode_tuple is
-  // true, the output tuples will be decomposed into their single elements.
-  virtual std::vector<std::vector<DataPtr>> ExecuteParallel(
-      absl::Span<const Computation* const> computations,
-      const std::vector<std::vector<DataPtr>>& arguments,
-      absl::Span<const std::string> devices,
-      const ExecuteParallelOptions& options) = 0;
-
-  // Executes a serie of operations, whose results are input of other
-  // operations. The ops is a valid post-order for the execution, which means
-  // that the inputs of op at index I, will have to be coming from ops at index
-  // lower than I. It returns a vector of device data shared pointers, one for
-  // every ExecuteChainedOp marked with is_result=true, in the order they appear
-  // within the ops post-order.
-  virtual std::vector<DataPtr> ExecuteChained(
-      absl::Span<const ExecuteChainedOp> ops, const std::string& device) = 0;
-
-  virtual std::vector<std::vector<DataPtr>> DeconstructTuple(
-      absl::Span<const DataPtr> tuples) = 0;
-
-  // Returns a unique string which identifies the resource domain of a given
-  // device. Within a resource domain, handles to device memory or compiled
-  // computations can be used for all devices part of such domain.
-  virtual std::string GetResourceDomain(const std::string& device) const = 0;
-
   virtual std::string GetDefaultDevice() const = 0;
   static Device* DefaultDevice();

@@ -297,11 +238,6 @@
   virtual swift_xla::Device GetDefaultDeviceStruct() const = 0;
   static swift_xla::Device DefaultDeviceStruct();

-  virtual size_t GetNumDevices() const = 0;
-
-  virtual std::vector<std::string> GetLocalDevices() const = 0;
-  static std::vector<std::string> LocalDevices();
-
   std::vector<std::string> GetAllDevices() const;
   static std::vector<std::string> AllDevices();

@@ -336,7 +272,6 @@
   // after the last ':' character of the device string.
   static int64 GetDeviceOrdinal(const std::string& device);

- protected:
   // Metrics common to all client intrfaces.
   static metrics::Metric* TransferToServerMetric();
   static metrics::Metric* TransferToServerTransformMetric();
@@ -357,19 +292,14 @@
   static metrics::Metric* ReleaseCompileHandlesTimeMetric();
   static metrics::Metric* InboundDataMetric();
   static metrics::Metric* OutboundDataMetric();
+
+ protected:
   void AddDevice(std::unique_ptr<Device> device);

   // Returns the ComputationClient singleton.
   static ComputationClient* Get();

  private:
-  virtual std::vector<Literal> TransferFromServerImpl(
-      absl::Span<const DataPtr> handles) = 0;
-  // Compiles a set of computations.
-  virtual std::vector<ComputationPtr> Compile(
-      const std::string& device, const std::vector<std::string>& devices,
-      std::vector<CompileInstance> instances) = 0;
-
   std::vector<Device*> devices_;
   std::vector<std::unique_ptr<Device>> devices_owned_;
   absl::node_hash_map<std::string, Device*> devices_by_name_;
