|
8 | 8 | import numbers |
9 | 9 | import itertools |
10 | 10 | import multiprocessing |
11 | | -import threading |
12 | 11 | import queue |
13 | 12 | from collections import namedtuple |
14 | 13 | from dataclasses import dataclass |
15 | 14 | import sys |
16 | 15 | import traceback |
17 | 16 |
|
18 | | - |
19 | 17 | def default_convert(data): |
20 | 18 | data_type = type(data) |
21 | 19 | if isinstance(data, np.ndarray): |
@@ -264,7 +262,6 @@ def __init__(self, loader): |
264 | 262 | self._persistent_workers = loader.persistent_workers |
265 | 263 | self._time_out = loader.time_out |
266 | 264 | self._sampler_iter = iter(self._index_sampler) |
267 | | - # self._pin_memory = loader.pin_memory |
268 | 265 | self._num_yielded = 0 |
269 | 266 |
|
270 | 267 | def __iter__(self): |
@@ -321,7 +318,6 @@ def __init__(self, loader): |
321 | 318 | self._worker_result_queue = multiprocessing.Queue() |
322 | 319 | self._worker_done_event = multiprocessing.Event() |
323 | 320 | self._worker_pids_set = False |
324 | | - self._shutdown = False |
325 | 321 |
|
326 | 322 | self._index_queues = [] |
327 | 323 | self._workers = [] |
@@ -357,6 +353,7 @@ def _reset(self, loader, first_iter=False): |
357 | 353 | while resume_iteration_cnt > 0: |
358 | 354 | return_idx, return_data = self._get_data() |
359 | 355 | if isinstance(return_idx, _ResumeIteration): |
| 356 | + assert return_data is None |
360 | 357 | resume_iteration_cnt -= 1 |
361 | 358 | for _ in range(self._prefetch_factor * self._num_workers): |
362 | 359 |
|
@@ -469,26 +466,19 @@ def _shutdown_workers(self): |
469 | 466 | if not self._shutdown: |
470 | 467 | self._shutdown = True |
471 | 468 | try: |
472 | | - if hasattr(self, '_pin_memory_thread'): |
473 | | - self._pin_memory_thread_done_event.set() |
474 | | - self._worker_result_queue.put((None, None)) |
475 | | - self._pin_memory_thread.join() |
476 | | - self._worker_result_queue.cancel_join_thread() |
477 | | - self._worker_result_queue.close() |
478 | | - |
479 | 469 | self._worker_done_event.set() |
480 | 470 | for worker_id in range(len(self._workers)): |
481 | 471 | if self._persistent_workers or self._workers_status[worker_id]: |
482 | 472 | self._mark_worker_as_unavailable(worker_id, shutdown=True) |
483 | 473 | for w in self._workers: |
484 | 474 | w.join(timeout=5.0) |
485 | | - if w.is_alive(): |
486 | | - w.terminate() |
487 | 475 | for q in self._index_queues: |
488 | 476 | q.cancel_join_thread() |
489 | 477 | q.close() |
490 | 478 | finally: |
491 | | - pass |
| 479 | + for w in self._workers: |
| 480 | + if w.is_alive(): |
| 481 | + w.terminate() |
492 | 482 |
|
def __del__(self):
    # Finalizer: best-effort teardown of worker processes/queues when the
    # iterator is garbage-collected, delegating to _shutdown_workers()
    # (which is idempotent via its `self._shutdown` flag).
    # NOTE(review): if __init__ failed before the shutdown bookkeeping
    # attributes were assigned, this call could raise AttributeError inside
    # the finalizer — confirm against the constructor, which is not fully
    # visible in this hunk.
    self._shutdown_workers()
@@ -565,7 +555,15 @@ def _worker_loop( |
565 | 555 | try: |
566 | 556 | data = fetcher.fetch(index) |
567 | 557 | except Exception as e: |
568 | | - data = ExceptionWrapper(where="in DataLoader worker process {}".format(worker_id)) |
| 558 | + if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iter: |
| 559 | + data = _IterableDatasetStopIteration(worker_id) |
| 560 | + iteration_end = True |
| 561 | + else: |
| 562 | + # It is important that we don't store exc_info in a variable. |
| 563 | + # `ExceptionWrapper` does the correct thing. |
| 564 | + # See NOTE [ Python Traceback Reference Cycle Problem ] |
| 565 | + data = ExceptionWrapper( |
| 566 | + where="in DataLoader worker process {}".format(worker_id)) |
569 | 567 | data_queue.put((idx, data)) |
570 | 568 | del data, idx, index, r |
571 | 569 | except KeyboardInterrupt: |
|
0 commit comments