Skip to content

Commit d531a5b

Browse files
pschuh, shawnwang18
authored and committed
Add StatusOr to the bond id in the transfer server's BulkTransportInterface in order to
forward errors from bond connection failures to the control plane connection.
PiperOrigin-RevId: 820783819
1 parent fd7a671 commit d531a5b

File tree

8 files changed

+51
-10
lines changed

8 files changed

+51
-10
lines changed

xla/python/transfer/socket-server.cc

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -299,10 +299,14 @@ class SocketServer::SocketNetworkState : public PollEventLoop::Handler {
299299
msg.data = const_cast<void*>(data);
300300
msg.size = size;
301301
msg.on_send = [val = tsl::FormRef(this), offset, req_id, is_largest](
302-
int bond_id, size_t size) {
302+
absl::StatusOr<int> bond_id, size_t size) {
303+
if (!bond_id.ok()) {
304+
val->SendError(req_id, offset, size, is_largest, bond_id.status());
305+
return;
306+
}
303307
SocketTransferRequest response;
304308
auto* packet = response.mutable_packet();
305-
packet->set_bulk_transport_id(bond_id);
309+
packet->set_bulk_transport_id(*bond_id);
306310
packet->set_offset(offset);
307311
packet->set_size(size);
308312
packet->set_req_id(req_id);

xla/python/transfer/socket_bulk_transport.cc

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,7 @@ class SendConnectionHandler : public PollEventLoop::Handler {
180180
artificial_send_limit_(artificial_send_limit) {}
181181

182182
~SendConnectionHandler() override {
183+
msg_queue_->Poison(absl::InternalError("A send connection has failed."));
183184
#ifdef MSG_ZEROCOPY
184185
table_.ClearAll();
185186
#endif
@@ -356,6 +357,19 @@ std::shared_ptr<SharedSendWorkQueue> SharedSendWorkQueue::Start() {
356357
return result;
357358
}
358359

360+
// Marks the queue as poisoned with `s` and fails every send that is still
// queued: each pending item's on_send is invoked with `s` and the item's size,
// followed by its on_done.
void SharedSendMsgQueue::Poison(absl::Status s) {
361+
mu_.lock();
362+
// Remembered so that ScheduleSendWork can fail later sends with this status.
poison_status_ = s;
363+
// Take the pending items out of the shared queue while mu_ is held.
// NOTE(review): work_items_ is left in a moved-from state; presumably empty —
// confirm, since ReportReadyToSend still checks work_items_.empty().
auto work_items = std::move(work_items_);
364+
mu_.unlock();
365+
// The callbacks below run on the local copy, after mu_ has been released.
while (!work_items.empty()) {
366+
auto work = std::move(work_items.front());
367+
work_items.pop_front();
368+
// on_send receives the error status in place of a bond id.
std::move(work.on_send)(s, work.size);
369+
std::move(work.on_done)();
370+
}
371+
}
372+
359373
void SharedSendMsgQueue::ReportReadyToSend(SendConnectionHandler* handler) {
360374
mu_.lock();
361375
if (!work_items_.empty()) {
@@ -375,6 +389,13 @@ void SharedSendMsgQueue::ReportReadyToSend(SendConnectionHandler* handler) {
375389
void SharedSendMsgQueue::ScheduleSendWork(
376390
aux::BulkTransportInterface::SendMessage msg) {
377391
mu_.lock();
392+
if (!poison_status_.ok()) {
393+
auto s = poison_status_;
394+
mu_.unlock();
395+
std::move(msg.on_send)(std::move(s), msg.size);
396+
std::move(msg.on_done)();
397+
return;
398+
}
378399
DCHECK(!shutdown_);
379400
if (work_items_.empty() && !handlers_.empty()) {
380401
auto* handler = handlers_.front();

xla/python/transfer/socket_bulk_transport.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,13 +129,16 @@ class SharedSendMsgQueue {
129129
std::shared_ptr<SharedSendWorkQueue> work_queue,
130130
size_t artificial_send_limiti = std::numeric_limits<size_t>::max());
131131

132+
void Poison(absl::Status s);
133+
132134
private:
133135
friend class SendConnectionHandler;
134136

135137
void ReportReadyToSend(SendConnectionHandler* handler);
136138

137139
absl::Mutex mu_;
138140
bool shutdown_ = false;
141+
absl::Status poison_status_;
139142
std::deque<SendConnectionHandler*> handlers_;
140143
std::deque<aux::BulkTransportInterface::SendMessage> work_items_;
141144
};

xla/python/transfer/socket_bulk_transport_test.cc

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ TEST(SendQueue, TestZeroCopyQueueCleanRemoteShutdown) {
8585
BulkTransportInterface::SendMessage msg;
8686
msg.data = txt_msg.data();
8787
msg.size = txt_msg.size();
88-
msg.on_send = [](int id, size_t size) {};
88+
msg.on_send = [](absl::StatusOr<int> id, size_t size) {};
8989
msg.on_done = [&notify]() { notify.Notify(); };
9090
msg_queue->ScheduleSendWork(std::move(msg));
9191
notify.WaitForNotification();
@@ -124,7 +124,7 @@ TEST(SendQueue, SendAndRecvQueuesArtificialLimit) {
124124
BulkTransportInterface::SendMessage msg;
125125
msg.data = txt_msg.data();
126126
msg.size = txt_msg.size();
127-
msg.on_send = [](int id, size_t size) {};
127+
msg.on_send = [](absl::StatusOr<int> id, size_t size) {};
128128
msg.on_done = [&mu, &send_count]() {
129129
absl::MutexLock l(mu);
130130
--send_count;
@@ -230,9 +230,9 @@ TEST(SocketBulkTransportFactoryTest, SendAndRecvWithFactory) {
230230
BulkTransportInterface::SendMessage msg;
231231
msg.data = txt_msgs[i].data();
232232
msg.size = txt_msgs[i].size();
233-
msg.on_send = [&, i](int id, size_t size) {
233+
msg.on_send = [&, i](absl::StatusOr<int> id, size_t size) {
234234
absl::MutexLock l(mu);
235-
send_queue.push_back({i, id});
235+
send_queue.push_back({i, id.value()});
236236
};
237237
msg.on_done = [&mu, &send_count]() {
238238
absl::MutexLock l(mu);

xla/python/transfer/streaming.cc

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,10 @@ BulkTransportInterface::SendMessage BulkTransportInterface::MakeMessage(
151151
SendMessage result;
152152
result.data = tmp->data();
153153
result.size = tmp->size();
154-
result.on_send = std::move(on_send);
154+
result.on_send = [on_send = std::move(on_send)](absl::StatusOr<int> bond_id,
155+
size_t size) mutable {
156+
std::move(on_send)(bond_id.value(), size);
157+
};
155158
result.on_done = [tmp = std::move(tmp)]() {};
156159
return result;
157160
}

xla/python/transfer/streaming.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,8 @@ class BulkTransportInterface {
107107
// There may be some delay between Send() and when the message
108108
// is actually sent. on_send gets called when the message actually
109109
// gets sent.
110-
absl::AnyInvocable<void(int bond_id, size_t size) &&> on_send;
110+
absl::AnyInvocable<void(absl::StatusOr<int> bond_id, size_t size) &&>
111+
on_send;
111112
};
112113

113114
// Schedules a send over a BulkTransportInterface connection.

xla/python/transfer/streaming_ifrt.cc

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,9 @@ void PremappedCopierState::StartWorkUnlocked(const WorkList& work_list) {
180180
--num_parallel_copies_;
181181
work_item->is_ready = true;
182182
work_item->result_status = s;
183-
FlushReadyWorkItemsInOrder();
183+
if (!currently_flushing_) {
184+
FlushReadyWorkItemsInOrder();
185+
}
184186
work_list2 = FindWorkLocked();
185187
}
186188
StartWorkUnlocked(work_list2);
@@ -194,14 +196,20 @@ void PremappedCopierState::FlushReadyWorkItemsInOrder() {
194196
if (!work_item->is_ready) {
195197
return;
196198
}
199+
if (!work_item->result_status.ok()) {
200+
available_copy_offsets_.push_back(work_item->dest_buffer);
201+
}
202+
currently_flushing_ = true;
203+
mu_.unlock();
197204
if (work_item->result_status.ok()) {
198205
std::move(work_item->on_done)(this, work_item->dest_buffer,
199206
work_item->work);
200207
} else {
201208
std::move(work_item->on_done)(this, work_item->result_status,
202209
work_item->work);
203-
available_copy_offsets_.push_back(work_item->dest_buffer);
204210
}
211+
mu_.lock();
212+
currently_flushing_ = false;
205213
work_queue_.pop_front();
206214
++base_seq_id_;
207215
}

xla/python/transfer/streaming_ifrt.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -117,6 +117,7 @@ class PremappedCopierState
117117
size_t num_parallel_copies_ = 0;
118118
std::deque<WorkQueueItem> work_queue_ ABSL_GUARDED_BY(mu_);
119119
std::shared_ptr<absl::Span<uint8_t>> scratch_;
120+
bool currently_flushing_ ABSL_GUARDED_BY(mu_) = false;
120121
size_t max_num_parallel_copies_;
121122
size_t xfer_size_;
122123
size_t max_copies_;

0 commit comments

Comments
 (0)